Compare commits

...

36 Commits

Author SHA1 Message Date
Rostislav Dugin
920c98e229 Merge pull request #397 from databasus/develop
FIX (migrations): Fix version of migrations tool goose
2026-02-22 23:43:55 +03:00
Rostislav Dugin
2a19a96aae FIX (migrations): Fix version of migrations tool goose 2026-02-22 23:43:23 +03:00
Rostislav Dugin
75aa2108d9 Merge pull request #396 from databasus/develop
FIX (email): Use current OS hostname instead of default localhost
2026-02-22 23:33:28 +03:00
Rostislav Dugin
0a0040839e FIX (email): Use current OS hostname instead of default localhost 2026-02-22 23:31:25 +03:00
Rostislav Dugin
ff4f795ece Merge pull request #394 from databasus/develop
FIX (nas): Add NAS share validation
2026-02-22 16:05:38 +03:00
Rostislav Dugin
dc05502580 FIX (nas): Add NAS share validation 2026-02-22 15:56:30 +03:00
Rostislav Dugin
1ca38f5583 Merge pull request #390 from databasus/develop
FEATURE (templates): Add PR template
2026-02-21 15:58:21 +03:00
Rostislav Dugin
40b3ff61c7 FEATURE (templates): Add PR template 2026-02-21 15:53:01 +03:00
Rostislav Dugin
e1b245a573 Merge pull request #389 from databasus/develop
Develop
2026-02-21 14:57:56 +03:00
Rostislav Dugin
fdf29b71f2 FIX (mongodb): Fix direct connection string parsing 2026-02-21 14:56:48 +03:00
Rostislav Dugin
49da981c21 Merge pull request #388 from databasus/main
Merge main into dev
2026-02-21 14:53:31 +03:00
Rostislav Dugin
9d611d3559 REFACTOR (mongodb): Refactor direct connection PR 2026-02-21 14:43:47 +03:00
ujstor
22cab53dab feature/mongodb-directConnection (#377)
FEATURE (mongodb): Add direct connection
2026-02-21 14:10:28 +03:00
Rostislav Dugin
d761c4156c Merge pull request #385 from databasus/develop
FIX (readme): Fix README typo
2026-02-20 17:17:45 +03:00
Rostislav Dugin
cbb8b82711 FIX (readme): Fix README typo 2026-02-20 17:01:44 +03:00
Rostislav Dugin
8e3d1e5bff Merge pull request #384 from databasus/develop
FIX (backups): Do not reload backups if request already in progress
2026-02-20 15:04:19 +03:00
Rostislav Dugin
349e7f0ee8 FIX (backups): Do not reload backups if request already in progress 2026-02-20 14:43:07 +03:00
Rostislav Dugin
3a274e135b Merge pull request #383 from databasus/develop
FEATURE (backups): Add GFS retention policy
2026-02-20 14:33:29 +03:00
Rostislav Dugin
61e937bc2a FEATURE (backups): Add GFS retention policy 2026-02-20 14:31:56 +03:00
Rostislav Dugin
f67919fe1a Merge pull request #374 from databasus/develop
FIX (backups): Fix backup download and clean up
2026-02-18 12:53:10 +03:00
Rostislav Dugin
91ee5966d8 FIX (backups): Fix backup download and clean up 2026-02-18 12:52:35 +03:00
Rostislav Dugin
d77d7d69a3 Merge pull request #371 from databasus/develop
FEATURE (backups): Add metadata alongsize with backup files itself to…
2026-02-17 19:54:53 +03:00
Rostislav Dugin
fc88b730d5 FEATURE (backups): Add metadata alongsize with backup files itself to make them recovarable without Databasus 2026-02-17 19:52:08 +03:00
Rostislav Dugin
1f1d80245f Merge pull request #368 from databasus/develop
FIX (restores): Increase restore timeout to 23 hours instead of 1 hour
2026-02-17 14:56:58 +03:00
Rostislav Dugin
16a29cf458 FIX (restores): Increase restore timeout to 23 hours instead of 1 hour 2026-02-17 14:56:25 +03:00
Rostislav Dugin
43e04500ac Merge pull request #367 from databasus/develop
FEATURE (backups): Add meaningful names for backups
2026-02-17 14:50:21 +03:00
Rostislav Dugin
cee3022f85 FEATURE (backups): Add meaningful names for backups 2026-02-17 14:49:33 +03:00
Rostislav Dugin
f46d92c480 Merge pull request #365 from databasus/develop
FIX (audit logs): Get rid of IDs in audit logs and improve naming log…
2026-02-15 01:10:54 +03:00
Rostislav Dugin
10677238d7 FIX (audit logs): Get rid of IDs in audit logs and improve naming logging 2026-02-15 01:06:39 +03:00
Rostislav Dugin
2553203fcf Merge pull request #363 from databasus/develop
FIX (sign up): Return authorization token on sign up to avoid 2-step …
2026-02-15 00:09:00 +03:00
Rostislav Dugin
7b05bd8000 FIX (sign up): Return authorization token on sign up to avoid 2-step sign up 2026-02-15 00:08:01 +03:00
Rostislav Dugin
8d45728f73 Merge pull request #362 from databasus/develop
FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign …
2026-02-14 23:19:12 +03:00
Rostislav Dugin
c70ad82c95 FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign up \ password reset 2026-02-14 23:11:36 +03:00
Rostislav Dugin
e4bc34d319 Merge pull request #361 from databasus/develop
Develop
2026-02-13 16:57:25 +03:00
Rostislav Dugin
257ae85da7 FIX (postgres): Fix read-only issue when user cannot access tables and partitions created after user creation 2026-02-13 16:56:56 +03:00
Rostislav Dugin
b42c820bb2 FIX (mariadb): Fix events exclusion 2026-02-13 16:21:48 +03:00
99 changed files with 4533 additions and 1064 deletions

38
.github/ISSUE_TEMPLATE/bug_report.md vendored Normal file
View File

@@ -0,0 +1,38 @@
---
name: Bug Report
about: Report a bug or unexpected behavior in Databasus
labels: bug
---
## Databasus version
<!-- e.g. 1.4.2 -->
## Operating system and architecture
<!-- e.g. Ubuntu 22.04 x64, macOS 14 ARM, Windows 11 x64 -->
## Describe the bug (please write manually, do not ask AI to summarize)
**What happened:**
**What I expected:**
## Steps to reproduce
1.
2.
3.
## Have you asked AI how to solve the issue?
<!-- Using AI to diagnose issues before filing a bug report helps narrow down root causes. -->
- [ ] Claude Sonnet 4.6 or newer
- [ ] ChatGPT 5.2 or newer
- [ ] No
## Additional context / logs
<!-- Screenshots, error messages, relevant log output, etc. -->

View File

@@ -407,7 +407,7 @@ jobs:
- name: Run database migrations
run: |
cd backend
go install github.com/pressly/goose/v3/cmd/goose@latest
go install github.com/pressly/goose/v3/cmd/goose@v3.24.3
goose up
- name: Run Go tests

View File

@@ -268,7 +268,8 @@ window.__RUNTIME_CONFIG__ = {
IS_CLOUD: '\${IS_CLOUD:-false}',
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED'
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
};
JSEOF

View File

@@ -11,7 +11,7 @@
[![MongoDB](https://img.shields.io/badge/MongoDB-47A248?logo=mongodb&logoColor=white)](https://www.mongodb.com/)
<br />
[![Apache 2.0 License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Docker Pulls](https://img.shields.io/docker/pulls/databasus/databasus?color=brightgreen)](https://hub.docker.com/r/databasus/databasus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/databasus/databasus)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/databasus/databasus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/databasus/databasus)
@@ -31,8 +31,6 @@
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
</div>
---
@@ -52,6 +50,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗑️ **Retention policies**
- **Time period**: Keep backups for a fixed duration (e.g., 7 days, 3 months, 1 year)
- **Count**: Keep a fixed number of the most recent backups (e.g., last 30)
- **GFS (Grandfather-Father-Son)**: Layered retention — keep hourly, daily, weekly, monthly and yearly backups independently for fine-grained long-term history (enterprises requirement)
- **Size limits**: Set per-backup and total storage size caps to control storage usage
### 🗄️ **Multiple storage destinations** <a href="https://databasus.com/storages">(view supported)</a>
- **Local storage**: Keep backups on your VPS/server
@@ -71,6 +76,8 @@
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
- **Read-only user**: Databasus uses a read-only user by default for backups and never stores anything that can modify your data
It is also important for Databasus that you are able to decrypt and restore backups from storages (local, S3, etc.) without Databasus itself. To do so, read our guide on [how to recover directly from storage](https://databasus.com/how-to-recover-without-databasus). We avoid "vendor lock-in" even to open source tool!
### 👥 **Suitable for teams** <a href="https://databasus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
@@ -220,8 +227,9 @@ For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart
3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
4. **Set database connection**: Enter your database credentials and connection details
5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
7. **Save and start**: Databasus will validate settings and begin the backup schedule
6. **Configure retention policy**: Choose time period, count or GFS to control how long backups are kept
7. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
8. **Save and start**: Databasus will validate settings and begin the backup schedule
### 🔑 Resetting password <a href="https://databasus.com/password">(docs)</a>
@@ -233,56 +241,22 @@ docker exec -it databasus ./main --new-password="YourNewSecurePassword123" --ema
Replace `admin` with the actual email address of the user whose password you want to reset.
### 💾 Backuping Databasus itself
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">backup your Databasus itself</a> or, at least, to copy secret key used for encryption (30 seconds is needed). So you are able to restore from your encrypted backups if you lose access to the server with Databasus or it is corrupted.
---
## 📝 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
---
## 🤝 Contributing
Contributions are welcome! Read the <a href="https://databasus.com/contribute">contributing guide</a> for more details, priorities and rules. If you want to contribute but don't know where to start, message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
Also you can join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
--
## 📖 Migration guide
Databasus is the new name for Postgresus. You can stay with latest version of Postgresus if you wish. If you want to migrate - follow installation steps for Databasus itself.
Just renaming an image is not enough as Postgresus and Databasus use different data folders and internal database naming.
You can put a new Databasus image with updated volume near the old Postgresus and run it (stop Postgresus before):
```
services:
databasus:
container_name: databasus
image: databasus/databasus:latest
ports:
- "4005:4005"
volumes:
- ./databasus-data:/databasus-data
restart: unless-stopped
```
Then manually move databases from Postgresus to Databasus.
### Why was Postgresus renamed to Databasus?
Databasus has been developed since 2023. It was internal tool to backup production and home projects databases. In start of 2025 it was released as open source project on GitHub. By the end of 2025 it became popular and the time for renaming has come in December 2025.
It was an important step for the project to grow. Actually, there are a couple of reasons:
1. Postgresus is no longer a little tool that just adds UI for pg_dump for little projects. It became a tool both for individual users, DevOps, DBAs, teams, companies and even large enterprises. Tens of thousands of users use Postgresus every day. Postgresus grew into a reliable backup management tool. Initial positioning is no longer suitable: the project is not just a UI wrapper, it's a solid backup management system now (despite it's still easy to use).
2. New databases are supported: although the primary focus is PostgreSQL (with 100% support in the most efficient way) and always will be, Databasus added support for MySQL, MariaDB and MongoDB. Later more databases will be supported.
3. Trademark issue: "postgres" is a trademark of PostgreSQL Inc. and cannot be used in the project name. So for safety and legal reasons, we had to rename the project.
## AI disclaimer
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.

View File

@@ -11,6 +11,9 @@ VICTORIA_LOGS_PASSWORD=devpassword
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# cloudflare turnstile
CLOUDFLARE_TURNSTILE_SITE_KEY=
CLOUDFLARE_TURNSTILE_SECRET_KEY=
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

View File

@@ -104,6 +104,10 @@ type EnvVariables struct {
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// Cloudflare Turnstile
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`

View File

@@ -1,7 +1,9 @@
package backuping
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
@@ -196,7 +198,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
backup,
backupConfig,
database,
storage,
@@ -263,7 +265,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
// Delete partial backup from storage
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.FileName); deleteErr != nil {
n.logger.Error(
"Failed to delete partial backup file",
"backupId",
@@ -311,6 +313,13 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backupMetadata.BackupID = backup.ID
if err := backupMetadata.Validate(); err != nil {
n.logger.Error("Failed to validate backup metadata", "error", err)
return
}
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
@@ -321,6 +330,39 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
return
}
// Save metadata file to storage
if backupMetadata != nil {
metadataJSON, err := json.Marshal(backupMetadata)
if err != nil {
n.logger.Error("Failed to marshal backup metadata to JSON",
"backupId", backup.ID,
"error", err,
)
} else {
metadataReader := bytes.NewReader(metadataJSON)
metadataFileName := backup.FileName + ".metadata"
if err := storage.SaveFile(
context.Background(),
n.fieldEncryptor,
n.logger,
metadataFileName,
metadataReader,
); err != nil {
n.logger.Error("Failed to save backup metadata file to storage",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
)
} else {
n.logger.Info("Backup metadata file saved successfully",
"backupId", backup.ID,
"fileName", metadataFileName,
)
}
}
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {

View File

@@ -18,7 +18,8 @@ import (
)
const (
cleanerTickerInterval = 1 * time.Minute
cleanerTickerInterval = 1 * time.Minute
recentBackupGracePeriod = 60 * time.Minute
)
type BackupCleaner struct {
@@ -51,8 +52,8 @@ func (c *BackupCleaner) Run(ctx context.Context) {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanOldBackups(); err != nil {
c.logger.Error("Failed to clean old backups", "error", err)
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
@@ -79,7 +80,7 @@ func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.ID)
err = storage.DeleteFile(c.fieldEncryptor, backup.FileName)
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
@@ -88,6 +89,11 @@ func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
c.logger.Error("Failed to delete backup file", "error", err)
}
metadataFileName := backup.FileName + ".metadata"
if err := storage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error("Failed to delete backup metadata file", "error", err)
}
return c.backupRepository.DeleteByID(backup.ID)
}
@@ -95,49 +101,30 @@ func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemo
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
func (c *BackupCleaner) cleanOldBackups() error {
func (c *BackupCleaner) cleanByRetentionPolicy() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
var cleanErr error
if backupStorePeriod == period.PeriodForever {
continue
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
if cleanErr != nil {
c.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
}
}
@@ -174,6 +161,158 @@ func (c *BackupCleaner) cleanExceededBackups() error {
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
if backupConfig.RetentionTimePeriod == period.PeriodForever {
return nil
}
storeDuration := backupConfig.RetentionTimePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
return fmt.Errorf(
"failed to find old backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
for _, backup := range oldBackups {
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
backupConfig.DatabaseID,
backups_core.BackupStatusCompleted,
)
if err != nil {
return fmt.Errorf(
"failed to find completed backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
// completedBackups are ordered newest first; delete everything beyond position RetentionCount
if len(completedBackups) <= backupConfig.RetentionCount {
return nil
}
toDelete := completedBackups[backupConfig.RetentionCount:]
for _, backup := range toDelete {
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
)
}
return nil
}
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
return nil
}
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
backupConfig.DatabaseID,
backups_core.BackupStatusCompleted,
)
if err != nil {
return fmt.Errorf(
"failed to find completed backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
keepSet := buildGFSKeepSet(
completedBackups,
backupConfig.RetentionGfsHours,
backupConfig.RetentionGfsDays,
backupConfig.RetentionGfsWeeks,
backupConfig.RetentionGfsMonths,
backupConfig.RetentionGfsYears,
)
for _, backup := range completedBackups {
if keepSet[backup.ID] {
continue
}
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
databaseID uuid.UUID,
limitperDbMB int64,
@@ -210,6 +349,21 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitperDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
@@ -240,3 +394,68 @@ func (c *BackupCleaner) cleanExceededBackupsForDatabase(
return nil
}
func isRecentBackup(backup *backups_core.Backup) bool {
return time.Since(backup.CreatedAt) < recentBackupGracePeriod
}
// buildGFSKeepSet determines which backups to retain under the GFS rotation scheme.
// Backups must be sorted newest-first. A backup can fill multiple slots simultaneously
// (e.g. the newest backup of a year also fills the monthly, weekly, daily, and hourly slot).
func buildGFSKeepSet(
backups []*backups_core.Backup,
hours, days, weeks, months, years int,
) map[uuid.UUID]bool {
keep := make(map[uuid.UUID]bool)
hoursSeen := make(map[string]bool)
daysSeen := make(map[string]bool)
weeksSeen := make(map[string]bool)
monthsSeen := make(map[string]bool)
yearsSeen := make(map[string]bool)
hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0
for _, backup := range backups {
t := backup.CreatedAt
hourKey := t.Format("2006-01-02-15")
dayKey := t.Format("2006-01-02")
weekYear, week := t.ISOWeek()
weekKey := fmt.Sprintf("%d-%02d", weekYear, week)
monthKey := t.Format("2006-01")
yearKey := t.Format("2006")
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] {
keep[backup.ID] = true
hoursSeen[hourKey] = true
hoursKept++
}
if days > 0 && daysKept < days && !daysSeen[dayKey] {
keep[backup.ID] = true
daysSeen[dayKey] = true
daysKept++
}
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] {
keep[backup.ID] = true
weeksSeen[weekKey] = true
weeksKept++
}
if months > 0 && monthsKept < months && !monthsSeen[monthKey] {
keep[backup.ID] = true
monthsSeen[monthKey] = true
monthsKept++
}
if years > 0 && yearsKept < years && !yearsSeen[yearKey] {
keep[backup.ID] = true
yearsSeen[yearKey] = true
yearsKept++
}
}
return keep
}

View File

@@ -25,24 +25,24 @@ var backupRepository = &backups_core.BackupRepository{}
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var backupCleaner = &BackupCleaner{
backupRepository: backupRepository,
storageService: storages.GetStorageService(),
backupConfigService: backups_config.GetBackupConfigService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
sync.Once{},
atomic.Bool{},
}
var backupNodesRegistry = &BackupNodesRegistry{
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
sync.Once{},
atomic.Bool{},
}
func getNodeID() uuid.UUID {
@@ -50,34 +50,35 @@ func getNodeID() uuid.UUID {
}
var backuperNode = &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: getNodeID(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: backuperNode,
runOnce: sync.Once{},
hasRun: atomic.Bool{},
backupRepository,
backups_config.GetBackupConfigService(),
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
sync.Once{},
atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {

View File

@@ -7,6 +7,7 @@ import (
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
@@ -32,7 +33,7 @@ type CreateFailedBackupUsecase struct{}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -46,7 +47,7 @@ type CreateSuccessBackupUsecase struct{}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -65,7 +66,7 @@ type CreateLargeBackupUsecase struct{}
func (uc *CreateLargeBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -84,7 +85,7 @@ type CreateProgressiveBackupUsecase struct{}
func (uc *CreateProgressiveBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -124,7 +125,7 @@ type CreateMediumBackupUsecase struct{}
func (uc *CreateMediumBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -152,7 +153,7 @@ func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
func (m *MockTrackingBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -162,7 +163,7 @@ func (m *MockTrackingBackupUsecase) Execute(
// Send backup ID to channel (non-blocking)
select {
case m.calledBackupIDs <- backupID:
case m.calledBackupIDs <- backup.ID:
default:
}

View File

@@ -13,7 +13,9 @@ import (
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
files_utils "databasus-backend/internal/util/files"
)
const (
@@ -27,6 +29,7 @@ type BackupsScheduler struct {
backupConfigService *backups_config.BackupConfigService
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
lastBackupTime time.Time
logger *slog.Logger
@@ -113,28 +116,28 @@ func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
return len(nodes) > 0
}
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
return
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
database.ID,
backups_core.BackupStatusInProgress,
)
if err != nil {
s.logger.Error(
"Failed to check for in-progress backups",
"databaseId",
databaseID,
database.ID,
"error",
err,
)
@@ -145,7 +148,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
s.logger.Warn(
"Backup already in progress for database, skipping new backup",
"databaseId",
databaseID,
database.ID,
"existingBackupId",
inProgressBackups[0].ID,
)
@@ -164,13 +167,22 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
fmt.Println("make backup")
backupID := uuid.New()
timestamp := time.Now().UTC()
backup := &backups_core.Backup{
ID: backupID,
FileName: fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(database.Name),
timestamp.Format("20060102-150405"),
backupID.String(),
),
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
CreatedAt: timestamp,
}
if err := s.backupRepository.Save(backup); err != nil {
@@ -224,8 +236,8 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
s.backupToNodeRelations[*leastBusyNodeID] = relation
} else {
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
NodeID: *leastBusyNodeID,
BackupsIDs: []uuid.UUID{backup.ID},
*leastBusyNodeID,
[]uuid.UUID{backup.ID},
}
}
@@ -329,7 +341,13 @@ func (s *BackupsScheduler) runPendingBackups() error {
backupConfig.BackupInterval.Interval,
)
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
continue
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}
}

View File

@@ -57,7 +57,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -126,7 +127,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -194,7 +196,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = false
@@ -266,7 +269,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
@@ -339,7 +343,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
@@ -410,7 +415,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -479,7 +485,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -492,7 +499,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.NoError(t, err)
// Scheduler assigns backup to mock node
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -582,7 +589,8 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -595,7 +603,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
assert.NoError(t, err)
// Start a backup and assign it to the node
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -759,7 +767,8 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -872,7 +881,8 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -892,7 +902,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
scheduler.StartBackup(database, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -975,7 +985,8 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -995,7 +1006,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database.ID, false)
scheduler.StartBackup(database, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
@@ -1069,7 +1080,8 @@ func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -1088,7 +1100,7 @@ func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
assert.NoError(t, err)
// Try to start a new backup - should be skipped
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(200 * time.Millisecond)
@@ -1140,7 +1152,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenW
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
@@ -1242,7 +1255,8 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.StorePeriod = period.PeriodWeek
backupConfig1.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig1.RetentionTimePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID
@@ -1259,7 +1273,8 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.StorePeriod = period.PeriodWeek
backupConfig2.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig2.RetentionTimePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID
@@ -1268,10 +1283,10 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1.ID, false)
scheduler.StartBackup(database1, false)
t.Log("Starting backup for database2")
scheduler.StartBackup(database2.ID, false)
scheduler.StartBackup(database2, false)
// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")

View File

@@ -1,17 +1,37 @@
package common
import backups_config "databasus-backend/internal/features/backups/config"
import (
backups_config "databasus-backend/internal/features/backups/config"
"errors"
type BackupType string
const (
BackupTypeDefault BackupType = "DEFAULT" // For MySQL, MongoDB, PostgreSQL legacy (-Fc)
BackupTypeDirectory BackupType = "DIRECTORY" // PostgreSQL directory type (-Fd)
"github.com/google/uuid"
)
type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
Type BackupType
BackupID uuid.UUID `json:"backupId"`
EncryptionSalt *string `json:"encryptionSalt"`
EncryptionIV *string `json:"encryptionIV"`
Encryption backups_config.BackupEncryption `json:"encryption"`
}
func (m *BackupMetadata) Validate() error {
if m.BackupID == uuid.Nil {
return errors.New("backup ID is required")
}
if m.Encryption == "" {
return errors.New("encryption is required")
}
if m.Encryption == backups_config.BackupEncryptionEncrypted {
if m.EncryptionSalt == nil {
return errors.New("encryption salt is required when encryption is enabled")
}
if m.EncryptionIV == nil {
return errors.New("encryption IV is required when encryption is enabled")
}
}
return nil
}

View File

@@ -6,6 +6,7 @@ import (
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
"fmt"
"io"
"net/http"
@@ -304,7 +305,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
_, err = io.Copy(ctx.Writer, rateLimitedReader)
if err != nil {
fmt.Printf("Error streaming file: %v\n", err)
return
}
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
@@ -322,7 +322,7 @@ func (c *BackupController) generateBackupFilename(
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
// Sanitize database name for filename (replace spaces and special chars)
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
// Determine extension based on database type
extension := c.getBackupExtension(database.Type)
@@ -346,33 +346,6 @@ func (c *BackupController) getBackupExtension(
}
}
func sanitizeFilename(name string) string {
// Replace characters that are invalid in filenames
replacer := map[rune]rune{
' ': '_',
'/': '-',
'\\': '-',
':': '-',
'*': '-',
'?': '-',
'"': '-',
'<': '-',
'>': '-',
'|': '-',
}
result := make([]rune, 0, len(name))
for _, char := range name {
if replacement, exists := replacer[char]; exists {
result = append(result, replacement)
} else {
result = append(result, char)
}
}
return string(result)
}
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()

View File

@@ -7,6 +7,8 @@ import (
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
@@ -18,6 +20,8 @@ import (
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
@@ -32,6 +36,7 @@ import (
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
@@ -956,7 +961,7 @@ func Test_SanitizeFilename(t *testing.T) {
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
result := sanitizeFilename(tt.input)
result := files_utils.SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
@@ -1244,6 +1249,86 @@ func Test_GenerateDownloadToken_BlockedWhenDownloadInProgress(t *testing.T) {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackup_VerifyBackupAndMetadataFilesExistInStorage(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, _, storage := createTestDatabaseWithBackups(workspace, owner, router)
backuperNode := backuping.CreateTestBackuperNode()
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backupRepo := &backups_core.BackupRepository{}
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+owner.Token,
request,
http.StatusOK,
)
backuping.WaitForBackupCompletion(t, database.ID, len(initialBackups), 30*time.Second)
backups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Greater(t, len(backups), len(initialBackups))
backup := backups[0]
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
storageService := storages.GetStorageService()
backupStorage, err := storageService.GetStorageByID(backup.StorageID)
assert.NoError(t, err)
encryptor := encryption.GetFieldEncryptor()
backupFile, err := backupStorage.GetFile(encryptor, backup.FileName)
assert.NoError(t, err)
backupFile.Close()
metadataFile, err := backupStorage.GetFile(encryptor, backup.FileName+".metadata")
assert.NoError(t, err)
metadataContent, err := io.ReadAll(metadataFile)
assert.NoError(t, err)
metadataFile.Close()
var storageMetadata backups_common.BackupMetadata
err = json.Unmarshal(metadataContent, &storageMetadata)
assert.NoError(t, err)
assert.Equal(t, backup.ID, storageMetadata.BackupID)
if backup.EncryptionSalt != nil && storageMetadata.EncryptionSalt != nil {
assert.Equal(t, *backup.EncryptionSalt, *storageMetadata.EncryptionSalt)
}
if backup.EncryptionIV != nil && storageMetadata.EncryptionIV != nil {
assert.Equal(t, *backup.EncryptionIV, *storageMetadata.EncryptionIV)
}
assert.Equal(t, backup.Encryption, storageMetadata.Encryption)
err = backupRepo.DeleteByID(backup.ID)
assert.NoError(t, err)
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
@@ -1407,7 +1492,7 @@ func createTestBackup(
context.Background(),
encryption.GetFieldEncryptor(),
logger,
backup.ID,
backup.ID.String(),
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
@@ -1720,3 +1805,84 @@ func Test_BandwidthThrottling_DynamicAdjustment(t *testing.T) {
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteBackup_RemovesBackupAndMetadataFilesFromDisk(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
backupConfig, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
backupConfig.IsBackupsEnabled = true
backupConfig.StorageID = &storage.ID
backupConfig.Storage = storage
_, err = configService.SaveBackupConfig(backupConfig)
assert.NoError(t, err)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backuperNode := backuping.CreateTestBackuperNode()
backuperCancel := backuping.StartBackuperNodeForTest(t, backuperNode)
defer backuping.StopBackuperNodeForTest(t, backuperCancel, backuperNode)
scheduler := backuping.CreateTestScheduler()
schedulerCancel := backuping.StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backupRepo := &backups_core.BackupRepository{}
initialBackups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
request := MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+owner.Token,
request,
http.StatusOK,
)
backuping.WaitForBackupCompletion(t, database.ID, len(initialBackups), 30*time.Second)
backups, err := backupRepo.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Greater(t, len(backups), len(initialBackups))
backup := backups[0]
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
dataFolder := config.GetEnv().DataFolder
backupFilePath := filepath.Join(dataFolder, backup.FileName)
metadataFilePath := filepath.Join(dataFolder, backup.FileName+".metadata")
_, err = os.Stat(backupFilePath)
assert.NoError(t, err, "backup file should exist on disk before deletion")
_, err = os.Stat(metadataFilePath)
assert.NoError(t, err, "metadata file should exist on disk before deletion")
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusNoContent,
)
_, err = os.Stat(backupFilePath)
assert.True(t, os.IsNotExist(err), "backup file should be removed from disk after deletion")
_, err = os.Stat(metadataFilePath)
assert.True(t, os.IsNotExist(err), "metadata file should be removed from disk after deletion")
}

View File

@@ -8,8 +8,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type NotificationSender interface {
@@ -23,7 +21,7 @@ type NotificationSender interface {
type CreateBackupUsecase interface {
Execute(
ctx context.Context,
backupID uuid.UUID,
backup *Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,

View File

@@ -8,7 +8,8 @@ import (
)
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`

View File

@@ -21,6 +21,7 @@ import (
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"github.com/google/uuid"
)
@@ -92,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
s.backupSchedulerService.StartBackup(databaseID, true)
s.backupSchedulerService.StartBackup(database, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -181,11 +182,7 @@ func (s *BackupService) DeleteBackup(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup deleted for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -232,11 +229,7 @@ func (s *BackupService) CancelBackup(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -276,11 +269,7 @@ func (s *BackupService) GetBackupFile(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -336,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.FileName)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
@@ -490,11 +479,7 @@ func (s *BackupService) WriteAuditLogForDownload(
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backup.ID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&userID,
database.WorkspaceID,
)
@@ -521,7 +506,7 @@ func (s *BackupService) generateBackupFilename(
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
extension := s.getBackupExtension(database.Type)
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
@@ -12,8 +13,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type CreateBackupUsecase struct {
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypePostgres:
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMysql:
return uc.CreateMysqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMariadb:
return uc.CreateMariadbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMongodb:
return uc.CreateMongodbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMariadbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadbDump,
@@ -115,7 +116,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
args = append(args, "--events")
}
@@ -135,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
func (uc *CreateMariadbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mariadbBin string,
args []string,
@@ -186,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -203,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -419,7 +427,9 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -46,7 +47,7 @@ type writeResult struct {
func (uc *CreateMongodbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongodump,
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(
func (uc *CreateMongodbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mongodumpBin string,
args []string,
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -262,6 +269,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
backupMetadata := common.BackupMetadata{
BackupID: backupID,
Encryption: backups_config.BackupEncryptionNone,
}
@@ -298,6 +306,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
backupMetadata.BackupID = backupID
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
backupMetadata.EncryptionSalt = &saltBase64
backupMetadata.EncryptionIV = &nonceBase64

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMysqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMysqlExecutable(
my.Version,
@@ -149,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.Mysq
func (uc *CreateMysqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mysqlBin string,
args []string,
@@ -200,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -217,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -431,7 +438,9 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -53,7 +54,7 @@ type writeResult struct {
func (uc *CreatePostgresqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetPostgresqlExecutable(
pg.Version,
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
// streamToStorage streams pg_dump output directly to storage
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
pgBin string,
args []string,
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -475,7 +482,9 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -118,9 +118,10 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -146,7 +147,7 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, period.PeriodWeek, response.RetentionTimePeriod)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
@@ -170,9 +171,10 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *test
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -337,7 +339,7 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, plan.MaxStoragePeriod, response.StorePeriod)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
@@ -411,9 +413,10 @@ func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testi
// Test 1: Try to save backup config with exceeded backup size limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -440,9 +443,10 @@ func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testi
// Test 2: Try to save backup config with exceeded total size limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -469,9 +473,10 @@ func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testi
// Test 3: Try to save backup config with exceeded storage period limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodYear, // Exceeds limit of Month
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -498,9 +503,10 @@ func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testi
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek, // Within Month limit
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -529,7 +535,7 @@ func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testi
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.StorePeriod)
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
@@ -618,9 +624,10 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -662,9 +669,10 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -959,9 +967,10 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1045,9 +1054,10 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *te
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1142,9 +1152,10 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest1 := BackupConfig{
DatabaseID: database1.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database1.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1168,9 +1179,10 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
)
backupConfigRequest2 := BackupConfig{
DatabaseID: database2.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database2.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1244,9 +1256,10 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1364,9 +1377,10 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database1.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database1.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1486,9 +1500,10 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database1.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database1.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1585,9 +1600,10 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1665,9 +1681,10 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadReques
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1730,9 +1747,10 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1789,9 +1807,10 @@ func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T
timeOfDay := "04:00"
backupConfigWithRegularStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1840,9 +1859,10 @@ func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T
assert.True(t, savedSystemStorage.IsSystem)
backupConfigWithSystemStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -13,3 +13,11 @@ const (
BackupEncryptionNone BackupEncryption = "NONE"
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)
type RetentionPolicyType string
const (
RetentionPolicyTypeTimePeriod RetentionPolicyType = "TIME_PERIOD"
RetentionPolicyTypeCount RetentionPolicyType = "COUNT"
RetentionPolicyTypeGFS RetentionPolicyType = "GFS"
)

View File

@@ -18,7 +18,15 @@ type BackupConfig struct {
IsBackupsEnabled bool `json:"isBackupsEnabled" gorm:"column:is_backups_enabled;type:boolean;not null"`
StorePeriod period.Period `json:"storePeriod" gorm:"column:store_period;type:text;not null"`
RetentionPolicyType RetentionPolicyType `json:"retentionPolicyType" gorm:"column:retention_policy_type;type:text;not null;default:'TIME_PERIOD'"`
RetentionTimePeriod period.TimePeriod `json:"retentionTimePeriod" gorm:"column:retention_time_period;type:text;not null;default:''"`
RetentionCount int `json:"retentionCount" gorm:"column:retention_count;type:int;not null;default:0"`
RetentionGfsHours int `json:"retentionGfsHours" gorm:"column:retention_gfs_hours;type:int;not null;default:0"`
RetentionGfsDays int `json:"retentionGfsDays" gorm:"column:retention_gfs_days;type:int;not null;default:0"`
RetentionGfsWeeks int `json:"retentionGfsWeeks" gorm:"column:retention_gfs_weeks;type:int;not null;default:0"`
RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
@@ -78,13 +86,12 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
}
func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
// Backup interval is required either as ID or as object
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
}
if b.StorePeriod == "" {
return errors.New("store period is required")
if err := b.validateRetentionPolicy(plan); err != nil {
return err
}
if b.IsRetryIfFailed && b.MaxFailedTriesCount <= 0 {
@@ -110,22 +117,12 @@ func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
return errors.New("max backups total size must be non-negative")
}
// Validate against plan limits
// Check storage period limit
if plan.MaxStoragePeriod != period.PeriodForever {
if b.StorePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
// Check max backup size limit (0 in plan means unlimited)
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
// Check max total backups size limit (0 in plan means unlimited)
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
@@ -140,7 +137,14 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
DatabaseID: newDatabaseID,
IsBackupsEnabled: b.IsBackupsEnabled,
StorePeriod: b.StorePeriod,
RetentionPolicyType: b.RetentionPolicyType,
RetentionTimePeriod: b.RetentionTimePeriod,
RetentionCount: b.RetentionCount,
RetentionGfsHours: b.RetentionGfsHours,
RetentionGfsDays: b.RetentionGfsDays,
RetentionGfsWeeks: b.RetentionGfsWeeks,
RetentionGfsMonths: b.RetentionGfsMonths,
RetentionGfsYears: b.RetentionGfsYears,
BackupIntervalID: uuid.Nil,
BackupInterval: b.BackupInterval.Copy(),
StorageID: b.StorageID,
@@ -152,3 +156,34 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
}
}
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
switch b.RetentionPolicyType {
case RetentionPolicyTypeTimePeriod, "":
if b.RetentionTimePeriod == "" {
return errors.New("retention time period is required")
}
if plan.MaxStoragePeriod != period.PeriodForever {
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
case RetentionPolicyTypeCount:
if b.RetentionCount <= 0 {
return errors.New("retention count must be greater than 0")
}
case RetentionPolicyTypeGFS:
if b.RetentionGfsHours <= 0 && b.RetentionGfsDays <= 0 && b.RetentionGfsWeeks <= 0 &&
b.RetentionGfsMonths <= 0 && b.RetentionGfsYears <= 0 {
return errors.New("at least one GFS retention field must be greater than 0")
}
default:
return errors.New("invalid retention policy type")
}
return nil
}

View File

@@ -11,9 +11,9 @@ import (
"github.com/stretchr/testify/assert"
)
func Test_Validate_WhenStoragePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodWeek
config.RetentionTimePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
@@ -22,9 +22,9 @@ func Test_Validate_WhenStoragePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
config.RetentionTimePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
@@ -33,9 +33,11 @@ func Test_Validate_WhenStoragePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsForever_ValidationPasses(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
t *testing.T,
) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
@@ -44,9 +46,9 @@ func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsForever_ValidationPass
assert.NoError(t, err)
}
func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
@@ -55,9 +57,9 @@ func Test_Validate_WhenStoragePeriodIsForeverAndPlanAllowsYear_ValidationFails(t
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenStoragePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodMonth
config.RetentionTimePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
@@ -178,7 +180,7 @@ func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodForever
config.RetentionTimePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
@@ -190,7 +192,7 @@ func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *tes
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = period.PeriodYear
config.RetentionTimePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
@@ -249,14 +251,14 @@ func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *te
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
func Test_Validate_WhenStoragePeriodIsEmpty_ValidationFails(t *testing.T) {
func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = ""
config.RetentionTimePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "store period is required")
assert.EqualError(t, err, "retention time period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
@@ -282,8 +284,8 @@ func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.Period
planPeriod period.Period
configPeriod period.TimePeriod
planPeriod period.TimePeriod
configSize int64
planSize int64
configTotal int64
@@ -345,7 +347,7 @@ func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.StorePeriod = tt.configPeriod
config.RetentionTimePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
@@ -364,12 +366,96 @@ func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
}
}
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "retention count must be greater than 0")
}
func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 10
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 0
config.RetentionGfsWeeks = 0
config.RetentionGfsMonths = 0
config.RetentionGfsYears = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}
func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 7
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
config.RetentionGfsDays = 7
config.RetentionGfsWeeks = 4
config.RetentionGfsMonths = 12
config.RetentionGfsYears = 3
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "invalid retention policy type")
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
StorePeriod: period.PeriodMonth,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},

View File

@@ -227,7 +227,8 @@ func (s *BackupConfigService) initializeDefaultConfig(
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
StorePeriod: plan.MaxStoragePeriod,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{

View File

@@ -35,9 +35,10 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -87,9 +88,10 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -131,9 +133,10 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -191,9 +194,10 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -15,9 +15,10 @@ func EnableBackupsForTestDatabase(
timeOfDay := "16:00"
backupConfig := &BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodDay,
DatabaseID: databaseID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodDay,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -25,13 +25,14 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.IsExcludeEvents = incoming.IsExcludeEvents
m.Privileges = incoming.Privileges
if incoming.Password != "" {

View File

@@ -25,15 +25,16 @@ type MongodbDatabase struct {
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port *int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database string `json:"database" gorm:"type:text;not null"`
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
Host string `json:"host" gorm:"type:text;not null"`
Port *int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database string `json:"database" gorm:"type:text;not null"`
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
IsDirectConnection bool `json:"isDirectConnection" gorm:"column:is_direct_connection;type:boolean;not null;default:false"`
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
}
func (m *MongodbDatabase) TableName() string {
@@ -132,6 +133,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
m.AuthDatabase = incoming.AuthDatabase
m.IsHttps = incoming.IsHttps
m.IsSrv = incoming.IsSrv
m.IsDirectConnection = incoming.IsDirectConnection
m.CpuCount = incoming.CpuCount
if incoming.Password != "" {
@@ -457,9 +459,12 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB = "admin"
}
tlsParams := ""
extraParams := ""
if m.IsHttps {
tlsParams = "&tls=true&tlsInsecure=true"
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
@@ -470,7 +475,7 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
m.Host,
m.Database,
authDB,
tlsParams,
extraParams,
)
}
@@ -487,7 +492,7 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
port,
m.Database,
authDB,
tlsParams,
extraParams,
)
}
@@ -498,9 +503,12 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB = "admin"
}
tlsParams := ""
extraParams := ""
if m.IsHttps {
tlsParams = "&tls=true&tlsInsecure=true"
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
@@ -510,7 +518,7 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
url.QueryEscape(password),
m.Host,
authDB,
tlsParams,
extraParams,
)
}
@@ -526,7 +534,7 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
m.Host,
port,
authDB,
tlsParams,
extraParams,
)
}

View File

@@ -631,6 +631,89 @@ func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
assert.NoError(t, err)
}
func Test_BuildConnectionURI_WithDirectConnection_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "directConnection=true")
assert.Contains(t, uri, "mongo.example.local:27017")
assert.Contains(t, uri, "authSource=admin")
}
func Test_BuildConnectionURI_WithoutDirectConnection_OmitsParam(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: false,
}
uri := model.buildConnectionURI("testpass123")
assert.NotContains(t, uri, "directConnection")
}
func Test_BuildMongodumpURI_WithDirectConnection_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "directConnection=true")
assert.NotContains(t, uri, "/mydb")
}
func Test_BuildConnectionURI_WithDirectConnectionAndTls_ReturnsBothParams(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: true,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "directConnection=true")
assert.Contains(t, uri, "tls=true")
assert.Contains(t, uri, "tlsInsecure=true")
}
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",

View File

@@ -564,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 4: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
// Step 4: Discover schemas to grant privileges on
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
var rows pgx.Rows
if len(p.IncludeSchemas) > 0 {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
AND schema_name = ANY($1::text[])
`, p.IncludeSchemas)
} else {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
}
if err != nil {
return "", "", fmt.Errorf("failed to get schemas: %w", err)
}
@@ -619,50 +630,197 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
// Step 6: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, grantSelectSQL)
if err != nil {
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 7: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// First, set default privileges for objects created by the current user
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
if err != nil {
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 8: Verify user creation before committing
// Step 8: Discover all roles that own objects in each schema
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
// Filter by IncludeSchemas if specified.
type SchemaOwner struct {
SchemaName string
RoleName string
}
var ownerRows pgx.Rows
if len(p.IncludeSchemas) > 0 {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname = ANY($1::text[])
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`, p.IncludeSchemas)
} else {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`)
}
if err != nil {
// Log warning but continue - this is a best-effort enhancement
logger.Warn("Failed to query object owners for default privileges", "error", err)
} else {
var schemaOwners []SchemaOwner
for ownerRows.Next() {
var so SchemaOwner
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
ownerRows.Close()
logger.Warn("Failed to scan schema owner", "error", err)
break
}
schemaOwners = append(schemaOwners, so)
}
ownerRows.Close()
if err := ownerRows.Err(); err != nil {
logger.Warn("Error iterating schema owners", "error", err)
}
// Step 9: Set default privileges FOR ROLE for each object owner
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
// We log warnings but continue - user creation should succeed even if some roles can't be configured
for _, so := range schemaOwners {
// Try to set default privileges for tables
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (tables)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
// Try to set default privileges for sequences
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (sequences)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
}
if len(schemaOwners) > 0 {
logger.Info(
"Set default privileges for existing object owners",
"readonly_user",
baseUsername,
"owner_count",
len(schemaOwners),
)
}
}
// Step 10: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)

View File

@@ -1319,6 +1319,346 @@ type PostgresContainer struct {
DB *sqlx.DB
}
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create a second database user who will create tables
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err := container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Step 2: Grant the user_creator privileges to connect and create tables
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CREATE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
// Step 2b: Create an initial table by user_creator so they become an object owner
// This is important because our fix discovers existing object owners
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
initialTableName := fmt.Sprintf(
"public.initial_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('initial_data');
`, initialTableName, initialTableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
}()
// Step 3: NOW create read-only user via Databasus (as admin)
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
tableName := fmt.Sprintf(
"public.future_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
`, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
}()
// Step 5: Connect as read-only user and verify it can SELECT from the new table
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
var count int
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
assert.NoError(t, err)
assert.Equal(
t,
2,
count,
"Read-only user should be able to SELECT from table created by different user",
)
// Step 6: Verify read-only user cannot write to the table
_, err = readonlyConn.Exec(
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify pg_dump operations (LOCK TABLE) work
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
tx, err := readonlyConn.Begin()
assert.NoError(t, err)
defer tx.Rollback()
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
err = tx.Commit()
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create multiple schemas and tables
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS included_schema CASCADE;
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
CREATE SCHEMA included_schema;
CREATE SCHEMA excluded_schema;
CREATE TABLE public.public_table (id INT, data TEXT);
INSERT INTO public.public_table VALUES (1, 'public_data');
CREATE TABLE included_schema.included_table (id INT, data TEXT);
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
`)
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
}()
// Step 2: Create a second user who owns tables in both included and excluded schemas
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Grant privileges to user_creator
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
schema,
userCreatorUsername,
))
assert.NoError(t, err)
}
// User_creator creates tables in included and excluded schemas
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.user_table (id INT, data TEXT);
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
`)
assert.NoError(t, err)
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
pgModel := createPostgresModel(container)
pgModel.IncludeSchemas = []string{"public", "included_schema"}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: Connect as read-only user
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
// Step 5: Verify read-only user CAN access included schemas
var publicData string
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "public_data", publicData)
var includedData string
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "included_data", includedData)
var userIncludedData string
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "user_included_data", userIncludedData)
// Step 6: Verify read-only user CANNOT access excluded schema
var excludedData string
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
err = readonlyConn.Get(
&excludedData,
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify future tables in included schemas are accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.future_table (id INT, data TEXT);
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
`)
assert.NoError(t, err)
var futureData string
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(
t,
"future_data",
futureData,
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
)
// Step 8: Verify future tables in excluded schema are NOT accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
`)
assert.NoError(t, err)
var futureExcludedData string
err = readonlyConn.Get(
&futureExcludedData,
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(
t,
err.Error(),
"permission denied",
"Read-only user should NOT access tables in excluded schemas",
)
}
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"

View File

@@ -192,6 +192,8 @@ func (s *DatabaseService) UpdateDatabase(
}
}
oldName := existingDatabase.Name
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -201,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
if oldName != existingDatabase.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database updated and renamed from '%s' to '%s'",
oldName,
existingDatabase.Name,
),
&user.ID,
existingDatabase.WorkspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
}
return nil
}
@@ -571,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
database.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
database.Name, sourceWorkspace.Name, targetWorkspace.Name),
nil,
&targetWorkspaceID,
)

View File

@@ -9,17 +9,17 @@ import (
"mime"
"net"
"net/smtp"
"os"
"time"
"github.com/google/uuid"
)
const (
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
DefaultHelloName = "localhost"
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
)
type EmailNotifier struct {
@@ -116,6 +116,16 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
return nil
}
func getHelloName() string {
hostname, err := os.Hostname()
if err != nil || hostname == "" {
return "localhost"
}
return hostname
}
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
func encodeRFC2047(s string) string {
@@ -131,6 +141,7 @@ func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte
encodedSubject := encodeRFC2047(heading)
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
messageID := fmt.Sprintf("Message-ID: <%s@%s>\r\n", uuid.New().String(), e.SMTPHost)
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
@@ -144,7 +155,7 @@ func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + dateHeader + mimeHeaders + message)
return []byte(fromHeader + toHeader + subject + dateHeader + messageID + mimeHeaders + message)
}
func (e *EmailNotifier) sendImplicitTLS(
@@ -219,7 +230,7 @@ func (e *EmailNotifier) createStartTLSClient() (*smtp.Client, func(), error) {
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
if err := client.Hello(getHelloName()); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)

View File

@@ -58,6 +58,8 @@ func (s *NotifierService) SaveNotifier(
return err
}
oldName := existingNotifier.Name
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -67,11 +69,23 @@ func (s *NotifierService) SaveNotifier(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
if oldName != existingNotifier.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Notifier updated and renamed from '%s' to '%s'",
oldName,
existingNotifier.Name,
),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
}
} else {
notifier.WorkspaceID = workspaceID
@@ -343,9 +357,19 @@ func (s *NotifierService) TransferNotifierToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Notifier transferred: %s from workspace '%s' to workspace '%s'",
existingNotifier.Name, sourceWorkspace.Name, targetWorkspace.Name),
&user.ID,
&targetWorkspaceID,
)

View File

@@ -9,9 +9,9 @@ import (
type DatabasePlan struct {
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
MaxStoragePeriod period.Period `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
MaxStoragePeriod period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}
func (p *DatabasePlan) TableName() string {

View File

@@ -261,7 +261,7 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Database restored from backup") &&
if strings.Contains(log.Message, "Database restored for database") &&
strings.Contains(log.Message, database.Name) {
found = true
break
@@ -752,7 +752,7 @@ func createTestBackup(
context.Background(),
fieldEncryptor,
logger,
backup.ID,
backup.ID.String(),
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))

View File

@@ -190,11 +190,7 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database restored from backup %s for database: %s",
backupID.String(),
database.Name,
),
fmt.Sprintf("Database restored for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -412,11 +408,7 @@ func (s *RestoreService) CancelRestore(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Restore cancelled for database: %s (ID: %s)",
database.Name,
restoreID.String(),
),
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)

View File

@@ -106,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
storage *storages.Storage,
mdbConfig *mariadbtypes.MariadbDatabase,
) error {
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
go func() {
@@ -141,7 +141,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -154,7 +154,7 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
// Stream backup directly from storage
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -105,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
storage *storages.Storage,
myConfig *mysqltypes.MysqlDatabase,
) error {
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
go func() {
@@ -140,7 +140,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}

View File

@@ -152,7 +152,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--no-acl",
}
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
// Monitor for shutdown and parent cancellation
@@ -209,7 +209,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
}
// Get backup stream from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
@@ -429,7 +429,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
// Monitor for shutdown and parent cancellation
@@ -540,12 +540,14 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)

View File

@@ -14,13 +14,13 @@ type StorageFileSaver interface {
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
GetFile(encryptor encryption.FieldEncryptor, fileName string) (io.ReadCloser, error)
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error
Validate(encryptor encryption.FieldEncryptor) error

View File

@@ -41,10 +41,10 @@ func (s *Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileID, file)
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileName, file)
if err != nil {
lastSaveError := err.Error()
s.LastSaveError = &lastSaveError
@@ -58,13 +58,13 @@ func (s *Storage) SaveFile(
func (s *Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(encryptor, fileID)
return s.getSpecificStorage().GetFile(encryptor, fileName)
}
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileName)
}
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {

View File

@@ -229,12 +229,12 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
fileID.String(),
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
file, err := tc.storage.GetFile(encryptor, fileID)
file, err := tc.storage.GetFile(encryptor, fileID.String())
assert.NoError(t, err, "GetFile should succeed")
defer file.Close()
@@ -252,15 +252,15 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
fileID.String(),
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
err = tc.storage.DeleteFile(encryptor, fileID)
err = tc.storage.DeleteFile(encryptor, fileID.String())
assert.NoError(t, err, "DeleteFile should succeed")
file, err := tc.storage.GetFile(encryptor, fileID)
file, err := tc.storage.GetFile(encryptor, fileID.String())
assert.Error(t, err, "GetFile should fail for non-existent file")
if file != nil {
file.Close()
@@ -270,7 +270,7 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
// Try to delete a non-existent file
nonExistentID := uuid.New()
err := tc.storage.DeleteFile(encryptor, nonExistentID)
err := tc.storage.DeleteFile(encryptor, nonExistentID.String())
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
})
})

View File

@@ -68,7 +68,7 @@ func (s *AzureBlobStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -82,7 +82,7 @@ func (s *AzureBlobStorage) SaveFile(
return err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
blockBlobClient := client.ServiceClient().
NewContainerClient(s.ContainerName).
NewBlockBlobClient(blobName)
@@ -157,14 +157,14 @@ func (s *AzureBlobStorage) SaveFile(
func (s *AzureBlobStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
response, err := client.DownloadStream(
context.TODO(),
@@ -179,13 +179,13 @@ func (s *AzureBlobStorage) GetFile(
return response.Body, nil
}
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
blobName := s.buildBlobName(fileName)
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
defer cancel()

View File

@@ -41,7 +41,7 @@ func (f *FTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,19 +50,19 @@ func (f *FTPStorage) SaveFile(
default:
}
logger.Info("Starting to save file to FTP storage", "fileId", fileID.String(), "host", f.Host)
logger.Info("Starting to save file to FTP storage", "fileName", fileName, "host", f.Host)
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to FTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to connect to FTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to connect to FTP: %w", err)
}
defer func() {
if quitErr := conn.Quit(); quitErr != nil {
logger.Error(
"Failed to close FTP connection",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
quitErr,
)
@@ -73,8 +73,8 @@ func (f *FTPStorage) SaveFile(
if err := f.ensureDirectory(conn, f.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
f.Path,
"error",
@@ -84,8 +84,8 @@ func (f *FTPStorage) SaveFile(
}
}
filePath := f.getFilePath(fileID.String())
logger.Debug("Uploading file to FTP", "fileId", fileID.String(), "filePath", filePath)
filePath := f.getFilePath(fileName)
logger.Debug("Uploading file to FTP", "fileName", fileName, "filePath", filePath)
ctxReader := &contextReader{ctx: ctx, reader: file}
@@ -93,18 +93,18 @@ func (f *FTPStorage) SaveFile(
if err != nil {
select {
case <-ctx.Done():
logger.Info("FTP upload cancelled", "fileId", fileID.String())
logger.Info("FTP upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error("Failed to upload file to FTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to upload file to FTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to upload file to FTP: %w", err)
}
}
logger.Info(
"Successfully saved file to FTP storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -113,14 +113,14 @@ func (f *FTPStorage) SaveFile(
func (f *FTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to FTP: %w", err)
}
filePath := f.getFilePath(fileID.String())
filePath := f.getFilePath(fileName)
resp, err := conn.Retr(filePath)
if err != nil {
@@ -134,7 +134,7 @@ func (f *FTPStorage) GetFile(
}, nil
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
defer cancel()
@@ -146,7 +146,7 @@ func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
_ = conn.Quit()
}()
filePath := f.getFilePath(fileID.String())
filePath := f.getFilePath(fileName)
_, err = conn.FileSize(filePath)
if err != nil {

View File

@@ -50,21 +50,19 @@ func (s *GoogleDriveStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
filename := fileID.String()
folderID, err := s.ensureBackupsFolderExists(ctx, driveService)
if err != nil {
return fmt.Errorf("failed to create/find backups folder: %w", err)
}
_ = s.deleteByName(ctx, driveService, filename, folderID)
_ = s.deleteByName(ctx, driveService, fileName, folderID)
fileMeta := &drive.File{
Name: filename,
Name: fileName,
Parents: []string{folderID},
}
@@ -91,7 +89,7 @@ func (s *GoogleDriveStorage) SaveFile(
logger.Info(
"file uploaded to Google Drive",
"name",
filename,
fileName,
"folder",
"databasus_backups",
)
@@ -152,7 +150,7 @@ func (r *backpressureReader) Read(p []byte) (n int, err error) {
func (s *GoogleDriveStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
var result io.ReadCloser
err := s.withRetryOnAuth(
@@ -164,7 +162,7 @@ func (s *GoogleDriveStorage) GetFile(
return fmt.Errorf("failed to find backups folder: %w", err)
}
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
fileIDGoogle, err := s.lookupFileID(driveService, fileName, folderID)
if err != nil {
return err
}
@@ -184,7 +182,7 @@ func (s *GoogleDriveStorage) GetFile(
func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) error {
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
defer cancel()
@@ -195,7 +193,7 @@ func (s *GoogleDriveStorage) DeleteFile(
return fmt.Errorf("failed to find backups folder: %w", err)
}
return s.deleteByName(ctx, driveService, fileID.String(), folderID)
return s.deleteByName(ctx, driveService, fileName, folderID)
})
}

View File

@@ -36,7 +36,7 @@ func (l *LocalStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -45,7 +45,7 @@ func (l *LocalStorage) SaveFile(
default:
}
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
logger.Info("Starting to save file to local storage", "fileName", fileName)
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
@@ -54,15 +54,15 @@ func (l *LocalStorage) SaveFile(
return fmt.Errorf("failed to ensure directories: %w", err)
}
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileID.String())
logger.Debug("Creating temp file", "fileId", fileID.String(), "tempPath", tempFilePath)
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
tempFile, err := os.Create(tempFilePath)
if err != nil {
logger.Error(
"Failed to create temp file",
"fileId",
fileID.String(),
"fileName",
fileName,
"tempPath",
tempFilePath,
"error",
@@ -74,29 +74,29 @@ func (l *LocalStorage) SaveFile(
_ = tempFile.Close()
}()
logger.Debug("Copying file data to temp file", "fileId", fileID.String())
logger.Debug("Copying file data to temp file", "fileName", fileName)
_, err = copyWithContext(ctx, tempFile, file)
if err != nil {
logger.Error("Failed to write to temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to write to temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to write to temp file: %w", err)
}
if err = tempFile.Sync(); err != nil {
logger.Error("Failed to sync temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to sync temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to sync temp file: %w", err)
}
// Close the temp file explicitly before moving it (required on Windows)
if err = tempFile.Close(); err != nil {
logger.Error("Failed to close temp file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to close temp file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to close temp file: %w", err)
}
finalPath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
finalPath := filepath.Join(config.GetEnv().DataFolder, fileName)
logger.Debug(
"Moving file from temp to final location",
"fileId",
fileID.String(),
"fileName",
fileName,
"finalPath",
finalPath,
)
@@ -105,8 +105,8 @@ func (l *LocalStorage) SaveFile(
if err = os.Rename(tempFilePath, finalPath); err != nil {
logger.Error(
"Failed to move file from temp to backups",
"fileId",
fileID.String(),
"fileName",
fileName,
"tempPath",
tempFilePath,
"finalPath",
@@ -119,8 +119,8 @@ func (l *LocalStorage) SaveFile(
logger.Info(
"Successfully saved file to local storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"finalPath",
finalPath,
)
@@ -130,12 +130,12 @@ func (l *LocalStorage) SaveFile(
func (l *LocalStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil, fmt.Errorf("file not found: %s", fileID.String())
return nil, fmt.Errorf("file not found: %s", fileName)
}
file, err := os.Open(filePath)
@@ -146,8 +146,8 @@ func (l *LocalStorage) GetFile(
return file, nil
}
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil

View File

@@ -46,7 +46,7 @@ func (n *NASStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -55,19 +55,19 @@ func (n *NASStorage) SaveFile(
default:
}
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
logger.Info("Starting to save file to NAS storage", "fileName", fileName, "host", n.Host)
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create NAS session", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create NAS session: %w", err)
}
defer func() {
if logoffErr := session.Logoff(); logoffErr != nil {
logger.Error(
"Failed to logoff NAS session",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
logoffErr,
)
@@ -78,8 +78,8 @@ func (n *NASStorage) SaveFile(
if err != nil {
logger.Error(
"Failed to mount NAS share",
"fileId",
fileID.String(),
"fileName",
fileName,
"share",
n.Share,
"error",
@@ -91,8 +91,8 @@ func (n *NASStorage) SaveFile(
if umountErr := fs.Umount(); umountErr != nil {
logger.Error(
"Failed to unmount NAS share",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
umountErr,
)
@@ -104,8 +104,8 @@ func (n *NASStorage) SaveFile(
if err := n.ensureDirectory(fs, n.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
n.Path,
"error",
@@ -115,15 +115,15 @@ func (n *NASStorage) SaveFile(
}
}
filePath := n.getFilePath(fileID.String())
logger.Debug("Creating file on NAS", "fileId", fileID.String(), "filePath", filePath)
filePath := n.getFilePath(fileName)
logger.Debug("Creating file on NAS", "fileName", fileName, "filePath", filePath)
nasFile, err := fs.Create(filePath)
if err != nil {
logger.Error(
"Failed to create file on NAS",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
"error",
@@ -133,21 +133,21 @@ func (n *NASStorage) SaveFile(
}
defer func() {
if closeErr := nasFile.Close(); closeErr != nil {
logger.Error("Failed to close NAS file", "fileId", fileID.String(), "error", closeErr)
logger.Error("Failed to close NAS file", "fileName", fileName, "error", closeErr)
}
}()
logger.Debug("Copying file data to NAS", "fileId", fileID.String())
logger.Debug("Copying file data to NAS", "fileName", fileName)
_, err = copyWithContext(ctx, nasFile, file)
if err != nil {
logger.Error("Failed to write file to NAS", "fileId", fileID.String(), "error", err)
logger.Error("Failed to write file to NAS", "fileName", fileName, "error", err)
return fmt.Errorf("failed to write file to NAS: %w", err)
}
logger.Info(
"Successfully saved file to NAS storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -156,7 +156,7 @@ func (n *NASStorage) SaveFile(
func (n *NASStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
session, err := n.createSession(encryptor)
if err != nil {
@@ -169,14 +169,14 @@ func (n *NASStorage) GetFile(
return nil, fmt.Errorf("failed to mount share '%s': %w", n.Share, err)
}
filePath := n.getFilePath(fileID.String())
filePath := n.getFilePath(fileName)
// Check if file exists
_, err = fs.Stat(filePath)
if err != nil {
_ = fs.Umount()
_ = session.Logoff()
return nil, fmt.Errorf("file not found: %s", fileID.String())
return nil, fmt.Errorf("file not found: %s", fileName)
}
nasFile, err := fs.Open(filePath)
@@ -194,7 +194,7 @@ func (n *NASStorage) GetFile(
}, nil
}
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
defer cancel()
@@ -214,7 +214,7 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
_ = fs.Umount()
}()
filePath := n.getFilePath(fileID.String())
filePath := n.getFilePath(fileName)
_, err = fs.Stat(filePath)
if err != nil {

View File

@@ -41,7 +41,7 @@ func (r *RcloneStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,28 +50,28 @@ func (r *RcloneStorage) SaveFile(
default:
}
logger.Info("Starting to save file to rclone storage", "fileId", fileID.String())
logger.Info("Starting to save file to rclone storage", "fileName", fileName)
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
logger.Error("Failed to create rclone filesystem", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create rclone filesystem", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
logger.Debug("Uploading file via rclone", "fileId", fileID.String(), "filePath", filePath)
filePath := r.getFilePath(fileName)
logger.Debug("Uploading file via rclone", "fileName", fileName, "filePath", filePath)
_, err = operations.Rcat(ctx, remoteFs, filePath, io.NopCloser(file), time.Now().UTC(), nil)
if err != nil {
select {
case <-ctx.Done():
logger.Info("Rclone upload cancelled", "fileId", fileID.String())
logger.Info("Rclone upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error(
"Failed to upload file via rclone",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
err,
)
@@ -81,8 +81,8 @@ func (r *RcloneStorage) SaveFile(
logger.Info(
"Successfully saved file to rclone storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -91,7 +91,7 @@ func (r *RcloneStorage) SaveFile(
func (r *RcloneStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
ctx := context.Background()
@@ -100,7 +100,7 @@ func (r *RcloneStorage) GetFile(
return nil, fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
filePath := r.getFilePath(fileName)
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {
@@ -115,7 +115,7 @@ func (r *RcloneStorage) GetFile(
return reader, nil
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
defer cancel()
@@ -124,7 +124,7 @@ func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID u
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
filePath := r.getFilePath(fileName)
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {

View File

@@ -55,7 +55,7 @@ func (s *S3Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -69,7 +69,7 @@ func (s *S3Storage) SaveFile(
return err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
uploadID, err := coreClient.NewMultipartUpload(
ctx,
@@ -184,14 +184,14 @@ func (s *S3Storage) SaveFile(
func (s *S3Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
object, err := client.GetObject(
context.TODO(),
@@ -221,13 +221,13 @@ func (s *S3Storage) GetFile(
return object, nil
}
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
objectKey := s.buildObjectKey(fileID.String())
objectKey := s.buildObjectKey(fileName)
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
defer cancel()

View File

@@ -41,7 +41,7 @@ func (s *SFTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
fileName string,
file io.Reader,
) error {
select {
@@ -50,19 +50,19 @@ func (s *SFTPStorage) SaveFile(
default:
}
logger.Info("Starting to save file to SFTP storage", "fileId", fileID.String(), "host", s.Host)
logger.Info("Starting to save file to SFTP storage", "fileName", fileName, "host", s.Host)
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to SFTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to connect to SFTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to connect to SFTP: %w", err)
}
defer func() {
if closeErr := client.Close(); closeErr != nil {
logger.Error(
"Failed to close SFTP client",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
closeErr,
)
@@ -70,8 +70,8 @@ func (s *SFTPStorage) SaveFile(
if closeErr := sshConn.Close(); closeErr != nil {
logger.Error(
"Failed to close SSH connection",
"fileId",
fileID.String(),
"fileName",
fileName,
"error",
closeErr,
)
@@ -82,8 +82,8 @@ func (s *SFTPStorage) SaveFile(
if err := s.ensureDirectory(client, s.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"fileName",
fileName,
"path",
s.Path,
"error",
@@ -93,12 +93,12 @@ func (s *SFTPStorage) SaveFile(
}
}
filePath := s.getFilePath(fileID.String())
logger.Debug("Uploading file to SFTP", "fileId", fileID.String(), "filePath", filePath)
filePath := s.getFilePath(fileName)
logger.Debug("Uploading file to SFTP", "fileName", fileName, "filePath", filePath)
remoteFile, err := client.Create(filePath)
if err != nil {
logger.Error("Failed to create remote file", "fileId", fileID.String(), "error", err)
logger.Error("Failed to create remote file", "fileName", fileName, "error", err)
return fmt.Errorf("failed to create remote file: %w", err)
}
defer func() {
@@ -111,18 +111,18 @@ func (s *SFTPStorage) SaveFile(
if err != nil {
select {
case <-ctx.Done():
logger.Info("SFTP upload cancelled", "fileId", fileID.String())
logger.Info("SFTP upload cancelled", "fileName", fileName)
return ctx.Err()
default:
logger.Error("Failed to upload file to SFTP", "fileId", fileID.String(), "error", err)
logger.Error("Failed to upload file to SFTP", "fileName", fileName, "error", err)
return fmt.Errorf("failed to upload file to SFTP: %w", err)
}
}
logger.Info(
"Successfully saved file to SFTP storage",
"fileId",
fileID.String(),
"fileName",
fileName,
"filePath",
filePath,
)
@@ -131,14 +131,14 @@ func (s *SFTPStorage) SaveFile(
func (s *SFTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
fileName string,
) (io.ReadCloser, error) {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to SFTP: %w", err)
}
filePath := s.getFilePath(fileID.String())
filePath := s.getFilePath(fileName)
remoteFile, err := client.Open(filePath)
if err != nil {
@@ -154,7 +154,7 @@ func (s *SFTPStorage) GetFile(
}, nil
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
defer cancel()
@@ -167,7 +167,7 @@ func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uui
_ = sshConn.Close()
}()
filePath := s.getFilePath(fileID.String())
filePath := s.getFilePath(fileName)
_, err = client.Stat(filePath)
if err != nil {

View File

@@ -92,6 +92,8 @@ func (s *StorageService) SaveStorage(
existingStorage.Update(storage)
oldName := existingStorage.Name
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
@@ -105,11 +107,19 @@ func (s *StorageService) SaveStorage(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
&user.ID,
&workspaceID,
)
if oldName != existingStorage.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage renamed from '%s' to '%s'", oldName, existingStorage.Name),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
&user.ID,
&workspaceID,
)
}
} else {
storage.WorkspaceID = workspaceID
@@ -368,9 +378,26 @@ func (s *StorageService) TransferStorageToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage transferred: %s from workspace %s to workspace %s",
existingStorage.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Storage transferred out: %s to workspace '%s'",
existingStorage.Name, targetWorkspace.Name),
&user.ID,
&sourceWorkspaceID,
)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage transferred in: %s from workspace '%s'",
existingStorage.Name, sourceWorkspace.Name),
&user.ID,
&targetWorkspaceID,
)

View File

@@ -147,6 +147,26 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
}
}
func Test_BackupAndRestoreMariadb_WithExcludeEvents_EventsNotRestored(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testMariadbBackupRestoreWithExcludeEventsForVersion(t, tc.version, tc.port)
})
}
}
func testMariadbBackupRestoreForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
@@ -702,3 +722,145 @@ func updateMariadbDatabaseCredentialsViaAPI(
return &updatedDatabase
}
func testMariadbBackupRestoreWithExcludeEventsForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
port string,
) {
container, err := connectToMariadbContainer(mariadbVersion, port)
if err != nil {
t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
return
}
defer func() {
if container.DB != nil {
container.DB.Close()
}
}()
setupMariadbTestData(t, container.DB)
_, err = container.DB.Exec(`
CREATE EVENT IF NOT EXISTS test_event
ON SCHEDULE EVERY 1 DAY
DO BEGIN
INSERT INTO test_data (name, value) VALUES ('event_test', 999);
END
`)
if err != nil {
t.Skipf(
"Skipping test: MariaDB version doesn't support events or event scheduler disabled: %v",
err,
)
return
}
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace(
"MariaDB Exclude Events Test Workspace",
user,
router,
)
storage := storages.CreateTestStorage(workspace.ID)
database := createMariadbDatabaseViaAPI(
t, router, "MariaDB Exclude Events Test Database", workspace.ID,
container.Host, container.Port,
container.Username, container.Password, container.Database,
container.Version,
user.Token,
)
database.Mariadb.IsExcludeEvents = true
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/update",
"Bearer "+user.Token,
database,
)
if w.Code != http.StatusOK {
t.Fatalf(
"Failed to update database with IsExcludeEvents. Status: %d, Body: %s",
w.Code,
w.Body.String(),
)
}
enableBackupsViaAPI(
t, router, database.ID, storage.ID,
backups_config.BackupEncryptionNone, user.Token,
)
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_no_events"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, newDBName)
newDB, err := sqlx.Connect("mysql", newDSN)
assert.NoError(t, err)
defer newDB.Close()
createMariadbRestoreViaAPI(
t, router, backup.ID,
container.Host, container.Port,
container.Username, container.Password, newDBName,
container.Version,
user.Token,
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
&tableExists,
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
verifyMariadbDataIntegrity(t, container.DB, newDB)
var eventCount int
err = newDB.Get(
&eventCount,
"SELECT COUNT(*) FROM information_schema.events WHERE event_schema = ? AND event_name = 'test_event'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(
t,
0,
eventCount,
"Event 'test_event' should NOT exist in restored database when IsExcludeEvents is true",
)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
t.Logf("Warning: Failed to delete backup file: %v", err)
}
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+user.Token,
http.StatusNoContent,
)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -726,41 +726,28 @@ func Test_InviteUserToWorkspace_MembershipReceivedAfterSignUp(t *testing.T) {
assert.Equal(t, workspaces_dto.AddStatusInvited, inviteResponse.Status)
// 3. Sign up the invited user
// 3. Sign up the invited user (now returns token directly)
signUpRequest := users_dto.SignUpRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
Name: "Invited User",
}
resp := test_utils.MakePostRequest(
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
signUpRequest,
http.StatusOK,
)
assert.Contains(t, string(resp.Body), "User created successfully")
// 4. Sign in the newly registered user
signInRequest := users_dto.SignInRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
}
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
signInRequest,
http.StatusOK,
&signInResponse,
)
// 5. Verify user is automatically added as member to workspace
assert.NotEmpty(t, signInResponse.Token)
assert.Equal(t, inviteEmail, signInResponse.Email)
// 4. Verify user is automatically added as member to workspace
var membersResponse workspaces_dto.GetMembersResponseDTO
test_utils.MakeGetRequestAndUnmarshal(
t,

View File

@@ -11,6 +11,7 @@ import (
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
cache_utils "databasus-backend/internal/util/cache"
cloudflare_turnstile "databasus-backend/internal/util/cloudflare_turnstile"
"github.com/gin-gonic/gin"
)
@@ -51,7 +52,7 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param request body users_dto.SignUpRequestDTO true "User signup data"
// @Success 200
// @Success 200 {object} users_dto.SignInResponseDTO
// @Failure 400
// @Router /users/signup [post]
func (c *UserController) SignUp(ctx *gin.Context) {
@@ -61,13 +62,41 @@ func (c *UserController) SignUp(ctx *gin.Context) {
return
}
err := c.userService.SignUp(&request)
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
user, err := c.userService.SignUp(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
response, err := c.userService.GenerateAccessToken(user)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
ctx.JSON(http.StatusOK, response)
}
// SignIn
@@ -88,6 +117,28 @@ func (c *UserController) SignIn(ctx *gin.Context) {
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
if !allowed {
ctx.JSON(
@@ -363,6 +414,28 @@ func (c *UserController) SendResetPasswordCode(ctx *gin.Context) {
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(
request.Email,
"reset-password",

View File

@@ -27,7 +27,20 @@ func Test_SignUpUser_WithValidData_UserCreated(t *testing.T) {
Name: "Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", request, http.StatusOK)
var response users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
request,
http.StatusOK,
&response,
)
assert.NotEmpty(t, response.Token)
assert.NotEqual(t, uuid.Nil, response.UserID)
assert.Equal(t, request.Email, response.Email)
}
func Test_SignUpUser_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {

View File

@@ -9,14 +9,16 @@ import (
)
type SignUpRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInResponseDTO struct {
@@ -94,7 +96,8 @@ type OAuthCallbackResponseDTO struct {
}
type SendResetPasswordCodeRequestDTO struct {
Email string `json:"email" binding:"required,email"`
Email string `json:"email" binding:"required,email"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type ResetPasswordRequestDTO struct {

View File

@@ -44,19 +44,19 @@ func (s *UserService) SetEmailSender(sender users_interfaces.EmailSender) {
s.emailSender = sender
}
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) (*users_models.User, error) {
existingUser, err := s.userRepository.GetUserByEmail(request.Email)
if err != nil {
return fmt.Errorf("failed to check existing user: %w", err)
return nil, fmt.Errorf("failed to check existing user: %w", err)
}
if existingUser != nil && existingUser.Status != users_enums.UserStatusInvited {
return errors.New("user with this email already exists")
return nil, errors.New("user with this email already exists")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
return nil, fmt.Errorf("failed to hash password: %w", err)
}
hashedPasswordStr := string(hashedPassword)
@@ -67,39 +67,45 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
existingUser.ID,
hashedPasswordStr,
); err != nil {
return fmt.Errorf("failed to set password: %w", err)
return nil, fmt.Errorf("failed to set password: %w", err)
}
if err := s.userRepository.UpdateUserStatus(
existingUser.ID,
users_enums.UserStatusActive,
); err != nil {
return fmt.Errorf("failed to activate user: %w", err)
return nil, fmt.Errorf("failed to activate user: %w", err)
}
name := request.Name
if err := s.userRepository.UpdateUserInfo(existingUser.ID, &name, nil); err != nil {
return fmt.Errorf("failed to update name: %w", err)
return nil, fmt.Errorf("failed to update name: %w", err)
}
// Fetch updated user to ensure we have the latest data
updatedUser, err := s.userRepository.GetUserByID(existingUser.ID)
if err != nil {
return nil, fmt.Errorf("failed to get updated user: %w", err)
}
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Invited user completed registration: %s", existingUser.Email),
&existingUser.ID,
fmt.Sprintf("Invited user completed registration: %s", updatedUser.Email),
&updatedUser.ID,
nil,
)
return nil
return updatedUser, nil
}
// Get settings to check registration policy for new users
settings, err := s.settingsService.GetSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
return nil, fmt.Errorf("failed to get settings: %w", err)
}
// Check if external registrations are allowed
if !settings.IsAllowExternalRegistrations {
return errors.New("external registration is disabled")
return nil, errors.New("external registration is disabled")
}
user := &users_models.User{
@@ -114,7 +120,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
}
if err := s.userRepository.CreateUser(user); err != nil {
return fmt.Errorf("failed to create user: %w", err)
return nil, fmt.Errorf("failed to create user: %w", err)
}
s.auditLogWriter.WriteAuditLog(
@@ -123,7 +129,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
nil,
)
return nil
return user, nil
}
func (s *UserService) SignIn(
@@ -258,6 +264,7 @@ func (s *UserService) GenerateAccessToken(
return &users_dto.SignInResponseDTO{
UserID: user.ID,
Email: user.Email,
Token: tokenString,
}, nil
}
@@ -383,7 +390,7 @@ func (s *UserService) InviteUser(
message := fmt.Sprintf("User invited: %s", request.Email)
if request.IntendedWorkspaceID != nil {
message += fmt.Sprintf(" for workspace %s", request.IntendedWorkspaceID.String())
message += " for workspace"
}
s.auditLogWriter.WriteAuditLog(
message,
@@ -430,6 +437,9 @@ func (s *UserService) UpdateUserInfo(
return fmt.Errorf("failed to get user: %w", err)
}
oldEmail := user.Email
oldName := user.Name
if user.Email == "admin" && request.Email != nil && *request.Email != user.Email {
return errors.New("admin email cannot be changed")
}
@@ -448,7 +458,28 @@ func (s *UserService) UpdateUserInfo(
return fmt.Errorf("failed to update user info: %w", err)
}
s.auditLogWriter.WriteAuditLog("User info updated", &userID, nil)
var auditMessages []string
if request.Email != nil && *request.Email != oldEmail {
auditMessages = append(
auditMessages,
fmt.Sprintf("Email changed from '%s' to '%s'", oldEmail, *request.Email),
)
}
if request.Name != nil && *request.Name != oldName {
auditMessages = append(
auditMessages,
fmt.Sprintf("Name changed from '%s' to '%s'", oldName, *request.Name),
)
}
if len(auditMessages) > 0 {
for _, message := range auditMessages {
s.auditLogWriter.WriteAuditLog(message, &userID, nil)
}
} else {
s.auditLogWriter.WriteAuditLog("User info updated", &userID, nil)
}
return nil
}
@@ -463,6 +494,178 @@ func (s *UserService) HandleGitHubOAuth(
)
}
func (s *UserService) HandleGoogleOAuth(
code, redirectUri string,
) (*users_dto.OAuthCallbackResponseDTO, error) {
return s.handleGoogleOAuthWithEndpoint(
code,
redirectUri,
google.Endpoint,
"https://www.googleapis.com/oauth2/v2/userinfo",
)
}
func (s *UserService) SendResetPasswordCode(email string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Silently succeed for non-existent users to prevent enumeration attacks
if user == nil {
return nil
}
// Only active users can reset passwords
if user.Status != users_enums.UserStatusActive {
return errors.New("only active users can reset their password")
}
// Check rate limiting - max 3 codes per hour
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
if err != nil {
return fmt.Errorf("failed to check rate limit: %w", err)
}
if recentCount >= 3 {
return errors.New("too many password reset attempts, please try again later")
}
// Generate 6-digit random code using crypto/rand for better randomness
codeNum := make([]byte, 4)
_, err = io.ReadFull(rand.Reader, codeNum)
if err != nil {
return fmt.Errorf("failed to generate random code: %w", err)
}
// Convert bytes to uint32 and modulo to get 6 digits
randomInt := uint32(
codeNum[0],
)<<24 | uint32(
codeNum[1],
)<<16 | uint32(
codeNum[2],
)<<8 | uint32(
codeNum[3],
)
code := fmt.Sprintf("%06d", randomInt%1000000)
// Hash the code
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash code: %w", err)
}
// Store in database with 1 hour expiration
resetCode := &users_models.PasswordResetCode{
ID: uuid.New(),
UserID: user.ID,
HashedCode: string(hashedCode),
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
IsUsed: false,
CreatedAt: time.Now().UTC(),
}
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
return fmt.Errorf("failed to create reset code: %w", err)
}
// Send email with code
if s.emailSender != nil {
subject := "Password Reset Code"
body := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Please use the following code to complete the password reset process:
</p>
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
</div>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
This code will expire in <strong>1 hour</strong>.
</p>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
</p>
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
This is an automated message. Please do not reply to this email.
</p>
</div>
</body>
</html>
`, code)
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Password reset code sent to: %s", user.Email),
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) ResetPassword(email, code, newPassword string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return errors.New("user with this email does not exist")
}
// Get valid reset code for user
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
if err != nil {
return errors.New("invalid or expired reset code")
}
// Verify code matches
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
if err != nil {
return errors.New("invalid reset code")
}
// Mark code as used
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
}
// Update user password
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
"Password reset via email code",
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) handleGitHubOAuthWithEndpoint(
code, redirectUri string,
endpoint oauth2.Endpoint,
@@ -529,17 +732,6 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
return s.getOrCreateUserFromOAuth(oauthID, email, name, "github")
}
func (s *UserService) HandleGoogleOAuth(
code, redirectUri string,
) (*users_dto.OAuthCallbackResponseDTO, error) {
return s.handleGoogleOAuthWithEndpoint(
code,
redirectUri,
google.Endpoint,
"https://www.googleapis.com/oauth2/v2/userinfo",
)
}
func (s *UserService) handleGoogleOAuthWithEndpoint(
code, redirectUri string,
endpoint oauth2.Endpoint,
@@ -805,164 +997,3 @@ func (s *UserService) fetchGitHubPrimaryEmail(
return "", errors.New("github account has no accessible email")
}
func (s *UserService) SendResetPasswordCode(email string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Silently succeed for non-existent users to prevent enumeration attacks
if user == nil {
return nil
}
// Only active users can reset passwords
if user.Status != users_enums.UserStatusActive {
return errors.New("only active users can reset their password")
}
// Check rate limiting - max 3 codes per hour
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
if err != nil {
return fmt.Errorf("failed to check rate limit: %w", err)
}
if recentCount >= 3 {
return errors.New("too many password reset attempts, please try again later")
}
// Generate 6-digit random code using crypto/rand for better randomness
codeNum := make([]byte, 4)
_, err = io.ReadFull(rand.Reader, codeNum)
if err != nil {
return fmt.Errorf("failed to generate random code: %w", err)
}
// Convert bytes to uint32 and modulo to get 6 digits
randomInt := uint32(
codeNum[0],
)<<24 | uint32(
codeNum[1],
)<<16 | uint32(
codeNum[2],
)<<8 | uint32(
codeNum[3],
)
code := fmt.Sprintf("%06d", randomInt%1000000)
// Hash the code
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash code: %w", err)
}
// Store in database with 1 hour expiration
resetCode := &users_models.PasswordResetCode{
ID: uuid.New(),
UserID: user.ID,
HashedCode: string(hashedCode),
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
IsUsed: false,
CreatedAt: time.Now().UTC(),
}
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
return fmt.Errorf("failed to create reset code: %w", err)
}
// Send email with code
if s.emailSender != nil {
subject := "Password Reset Code"
body := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Please use the following code to complete the password reset process:
</p>
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
</div>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
This code will expire in <strong>1 hour</strong>.
</p>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
</p>
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
This is an automated message. Please do not reply to this email.
</p>
</div>
</body>
</html>
`, code)
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Password reset code sent to: %s", user.Email),
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) ResetPassword(email, code, newPassword string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return errors.New("user with this email does not exist")
}
// Get valid reset code for user
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
if err != nil {
return errors.New("invalid or expired reset code")
}
// Verify code matches
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
if err != nil {
return errors.New("invalid reset code")
}
// Mark code as used
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
}
// Update user password
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
"Password reset via email code",
&user.ID,
nil,
)
}
return nil
}

View File

@@ -129,6 +129,8 @@ func (s *WorkspaceService) UpdateWorkspace(
return nil, fmt.Errorf("failed to get workspace: %w", err)
}
oldName := existingWorkspace.Name
updateDTO.ID = workspaceID
updateDTO.CreatedAt = existingWorkspace.CreatedAt
@@ -138,11 +140,19 @@ func (s *WorkspaceService) UpdateWorkspace(
return nil, fmt.Errorf("failed to update workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Workspace updated: %s", updateDTO.Name),
&user.ID,
&workspaceID,
)
if oldName != updateDTO.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Workspace updated and renamed from '%s' to '%s'", oldName, updateDTO.Name),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Workspace updated: %s", updateDTO.Name),
&user.ID,
&workspaceID,
)
}
return existingWorkspace, nil
}

View File

@@ -0,0 +1,71 @@
package cloudflare_turnstile
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
)
type CloudflareTurnstileService struct {
secretKey string
siteKey string
}
type cloudflareTurnstileResponse struct {
Success bool `json:"success"`
ChallengeTS time.Time `json:"challenge_ts"`
Hostname string `json:"hostname"`
ErrorCodes []string `json:"error-codes"`
}
const cloudflareTurnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
func (s *CloudflareTurnstileService) IsEnabled() bool {
return s.secretKey != ""
}
func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool, error) {
if !s.IsEnabled() {
return true, nil
}
if token == "" {
return false, errors.New("cloudflare Turnstile token is required")
}
formData := url.Values{}
formData.Set("secret", s.secretKey)
formData.Set("response", token)
formData.Set("remoteip", remoteIP)
resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData)
if err != nil {
return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("failed to read Cloudflare Turnstile response: %w", err)
}
var turnstileResp cloudflareTurnstileResponse
if err := json.Unmarshal(body, &turnstileResp); err != nil {
return false, fmt.Errorf("failed to parse Cloudflare Turnstile response: %w", err)
}
if !turnstileResp.Success {
return false, fmt.Errorf(
"cloudflare Turnstile verification failed: %v",
turnstileResp.ErrorCodes,
)
}
return true, nil
}

View File

@@ -0,0 +1,14 @@
package cloudflare_turnstile
import (
"databasus-backend/internal/config"
)
var cloudflareTurnstileService = &CloudflareTurnstileService{
config.GetEnv().CloudflareTurnstileSecretKey,
config.GetEnv().CloudflareTurnstileSiteKey,
}
func GetCloudflareTurnstileService() *CloudflareTurnstileService {
return cloudflareTurnstileService
}

View File

@@ -0,0 +1,48 @@
package files_utils
// SanitizeFilename replaces characters that are invalid or problematic in filenames
// across different operating systems (Windows, Linux, macOS) and storage systems
// (local filesystem, S3, FTP, SFTP, NAS, rclone, Azure Blob, Google Drive).
//
// The following characters are replaced:
// - Space (' ') -> underscore ('_')
// - Forward slash ('/') -> hyphen ('-')
// - Backslash ('\') -> hyphen ('-')
// - Colon (':') -> hyphen ('-')
// - Asterisk ('*') -> hyphen ('-')
// - Question mark ('?') -> hyphen ('-')
// - Double quote ('"') -> hyphen ('-')
// - Less than ('<') -> hyphen ('-')
// - Greater than ('>') -> hyphen ('-')
// - Pipe ('|') -> hyphen ('-')
//
// This ensures filenames work correctly on:
// - Windows (strict filename rules)
// - Unix/Linux/macOS (forward slashes are path separators)
// - All cloud storage providers (S3, Azure Blob, Google Drive)
// - Network storage (FTP, SFTP, NAS, rclone)
func SanitizeFilename(name string) string {
replacer := map[rune]rune{
' ': '_',
'/': '-',
'\\': '-',
':': '-',
'*': '-',
'?': '-',
'"': '-',
'<': '-',
'>': '-',
'|': '-',
}
result := make([]rune, 0, len(name))
for _, char := range name {
if replacement, exists := replacer[char]; exists {
result = append(result, replacement)
} else {
result = append(result, char)
}
}
return string(result)
}

View File

@@ -0,0 +1,217 @@
package files_utils
import (
"testing"
"github.com/stretchr/testify/assert"
)
func Test_SanitizeFilename_ReplacesSpecialCharacters(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "replaces spaces with underscores",
input: "my database name",
expected: "my_database_name",
},
{
name: "replaces forward slashes",
input: "db/prod/main",
expected: "db-prod-main",
},
{
name: "replaces backslashes",
input: "db\\prod\\main",
expected: "db-prod-main",
},
{
name: "replaces colons",
input: "db:production:main",
expected: "db-production-main",
},
{
name: "replaces asterisks",
input: "db*wildcard",
expected: "db-wildcard",
},
{
name: "replaces question marks",
input: "db?query",
expected: "db-query",
},
{
name: "replaces double quotes",
input: "db\"quoted\"name",
expected: "db-quoted-name",
},
{
name: "replaces less than signs",
input: "db<redirect",
expected: "db-redirect",
},
{
name: "replaces greater than signs",
input: "db>output",
expected: "db-output",
},
{
name: "replaces pipes",
input: "db|pipe",
expected: "db-pipe",
},
{
name: "replaces multiple different special characters",
input: "my db:/backup\\file*2024?",
expected: "my_db--backup-file-2024-",
},
{
name: "handles all special characters at once",
input: " /\\:*?\"<>|",
expected: "_---------",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func Test_SanitizeFilename_HandlesEdgeCases(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "empty string returns empty string",
input: "",
expected: "",
},
{
name: "string with no special characters remains unchanged",
input: "simple_database_name",
expected: "simple_database_name",
},
{
name: "string with hyphens and underscores remains unchanged",
input: "my-database_name-123",
expected: "my-database_name-123",
},
{
name: "preserves alphanumeric characters",
input: "Database123ABC",
expected: "Database123ABC",
},
{
name: "preserves dots and parentheses",
input: "db.production.(v2)",
expected: "db.production.(v2)",
},
{
name: "handles unicode characters",
input: "базаанных_テスト",
expected: "базаанных_テスト",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func Test_SanitizeFilename_WindowsReservedNames(t *testing.T) {
// Windows reserved names are case-insensitive: CON, PRN, AUX, NUL, COM1-COM9, LPT1-LPT9
// Our function doesn't handle these specifically because:
// 1. Database names in our system are typically lowercase
// 2. These are combined with timestamps and UUIDs in filenames (e.g., "CON-20240102-150405-uuid")
// 3. The timestamp and UUID suffix make the final filename safe on Windows
tests := []struct {
name string
input string
expected string
}{
{
name: "CON remains as CON (will be safe with timestamp suffix)",
input: "CON",
expected: "CON",
},
{
name: "PRN remains as PRN (will be safe with timestamp suffix)",
input: "PRN",
expected: "PRN",
},
{
name: "COM1 remains as COM1 (will be safe with timestamp suffix)",
input: "COM1",
expected: "COM1",
},
{
name: "handles database name with reserved name as part",
input: "my:CON/database",
expected: "my-CON-database",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}
func Test_SanitizeFilename_RealWorldExamples(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "production database with environment",
input: "prod:main/db",
expected: "prod-main-db",
},
{
name: "database with spaces and version",
input: "My App Database v2.0",
expected: "My_App_Database_v2.0",
},
{
name: "database with special query chars",
input: "analytics?region=us*",
expected: "analytics-region=us-",
},
{
name: "windows-style path in database name",
input: "C:\\databases\\prod",
expected: "C--databases-prod",
},
{
name: "unix-style path in database name",
input: "/var/lib/postgres/main",
expected: "-var-lib-postgres-main",
},
{
name: "database name with quotes",
input: "\"production\" database",
expected: "-production-_database",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SanitizeFilename(tt.input)
assert.Equal(t, tt.expected, result)
})
}
}

View File

@@ -2,24 +2,24 @@ package period
import "time"
type Period string
type TimePeriod string
const (
PeriodDay Period = "DAY"
PeriodWeek Period = "WEEK"
PeriodMonth Period = "MONTH"
Period3Month Period = "3_MONTH"
Period6Month Period = "6_MONTH"
PeriodYear Period = "YEAR"
Period2Years Period = "2_YEARS"
Period3Years Period = "3_YEARS"
Period4Years Period = "4_YEARS"
Period5Years Period = "5_YEARS"
PeriodForever Period = "FOREVER"
PeriodDay TimePeriod = "DAY"
PeriodWeek TimePeriod = "WEEK"
PeriodMonth TimePeriod = "MONTH"
Period3Month TimePeriod = "3_MONTH"
Period6Month TimePeriod = "6_MONTH"
PeriodYear TimePeriod = "YEAR"
Period2Years TimePeriod = "2_YEARS"
Period3Years TimePeriod = "3_YEARS"
Period4Years TimePeriod = "4_YEARS"
Period5Years TimePeriod = "5_YEARS"
PeriodForever TimePeriod = "FOREVER"
)
// ToDuration converts Period to time.Duration
func (p Period) ToDuration() time.Duration {
func (p TimePeriod) ToDuration() time.Duration {
switch p {
case PeriodDay:
return 24 * time.Hour
@@ -55,7 +55,7 @@ func (p Period) ToDuration() time.Duration {
// 1 if p > other
//
// FOREVER is treated as the longest period
func (p Period) CompareTo(other Period) int {
func (p TimePeriod) CompareTo(other TimePeriod) int {
if p == other {
return 0
}

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN IF NOT EXISTS is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
DROP COLUMN IF EXISTS is_exclude_events;
-- +goose StatementEnd

View File

@@ -0,0 +1,17 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE backups ADD COLUMN file_name TEXT;
-- +goose StatementEnd
-- +goose StatementBegin
UPDATE backups SET file_name = id::TEXT WHERE file_name IS NULL;
-- +goose StatementEnd
-- +goose StatementBegin
ALTER TABLE backups ALTER COLUMN file_name SET NOT NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE backups DROP COLUMN file_name;
-- +goose StatementEnd

View File

@@ -0,0 +1,38 @@
-- +goose Up
ALTER TABLE backup_configs
ADD COLUMN retention_policy_type TEXT NOT NULL DEFAULT 'TIME_PERIOD',
ADD COLUMN retention_time_period TEXT NOT NULL DEFAULT '',
ADD COLUMN retention_count INT NOT NULL DEFAULT 0,
ADD COLUMN retention_gfs_hours INT NOT NULL DEFAULT 0,
ADD COLUMN retention_gfs_days INT NOT NULL DEFAULT 0,
ADD COLUMN retention_gfs_weeks INT NOT NULL DEFAULT 0,
ADD COLUMN retention_gfs_months INT NOT NULL DEFAULT 0,
ADD COLUMN retention_gfs_years INT NOT NULL DEFAULT 0;
UPDATE backup_configs
SET retention_time_period = store_period;
ALTER TABLE backup_configs
DROP COLUMN store_period;
-- +goose Down
ALTER TABLE backup_configs
ADD COLUMN store_period TEXT NOT NULL DEFAULT 'WEEK';
UPDATE backup_configs
SET store_period = CASE
WHEN retention_time_period != '' THEN retention_time_period
ELSE 'WEEK'
END;
ALTER TABLE backup_configs
DROP COLUMN retention_policy_type,
DROP COLUMN retention_time_period,
DROP COLUMN retention_count,
DROP COLUMN retention_gfs_hours,
DROP COLUMN retention_gfs_days,
DROP COLUMN retention_gfs_weeks,
DROP COLUMN retention_gfs_months,
DROP COLUMN retention_gfs_years;

View File

@@ -0,0 +1,9 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mongodb_databases ADD COLUMN is_direct_connection BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mongodb_databases DROP COLUMN is_direct_connection;
-- +goose StatementEnd

View File

@@ -2,4 +2,5 @@ MODE=development
VITE_GITHUB_CLIENT_ID=
VITE_GOOGLE_CLIENT_ID=
VITE_IS_EMAIL_CONFIGURED=false
VITE_IS_CLOUD=false
VITE_IS_CLOUD=false
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=

View File

@@ -3,6 +3,7 @@ interface RuntimeConfig {
GITHUB_CLIENT_ID?: string;
GOOGLE_CLIENT_ID?: string;
IS_EMAIL_CONFIGURED?: string;
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
}
declare global {
@@ -39,6 +40,11 @@ export const IS_EMAIL_CONFIGURED =
window.__RUNTIME_CONFIG__?.IS_EMAIL_CONFIGURED === 'true' ||
import.meta.env.VITE_IS_EMAIL_CONFIGURED === 'true';
export const CLOUDFLARE_TURNSTILE_SITE_KEY =
window.__RUNTIME_CONFIG__?.CLOUDFLARE_TURNSTILE_SITE_KEY ||
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
'';
export function getOAuthRedirectUri(): string {
return `${window.location.origin}/auth/callback`;
}

View File

@@ -5,5 +5,6 @@ export type { Backup } from './model/Backup';
export type { BackupConfig } from './model/BackupConfig';
export { BackupNotificationType } from './model/BackupNotificationType';
export { BackupEncryption } from './model/BackupEncryption';
export { RetentionPolicyType } from './model/RetentionPolicyType';
export type { TransferDatabaseRequest } from './model/TransferDatabaseRequest';
export type { DatabasePlan } from '../plan';

View File

@@ -3,12 +3,22 @@ import type { Interval } from '../../intervals';
import type { Storage } from '../../storages';
import { BackupEncryption } from './BackupEncryption';
import type { BackupNotificationType } from './BackupNotificationType';
import type { RetentionPolicyType } from './RetentionPolicyType';
export interface BackupConfig {
databaseId: string;
isBackupsEnabled: boolean;
storePeriod: Period;
retentionPolicyType: RetentionPolicyType;
retentionTimePeriod: Period;
retentionCount: number;
retentionGfsHours: number;
retentionGfsDays: number;
retentionGfsWeeks: number;
retentionGfsMonths: number;
retentionGfsYears: number;
backupInterval?: Interval;
storage?: Storage;
sendNotificationsOn: BackupNotificationType[];

View File

@@ -0,0 +1,5 @@
export enum RetentionPolicyType {
TimePeriod = 'TIME_PERIOD',
Count = 'COUNT',
GFS = 'GFS',
}

View File

@@ -456,6 +456,46 @@ describe('MongodbConnectionStringParser', () => {
});
});
describe('Direct Connection Handling', () => {
it('should parse directConnection=true from URI', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'mongodb://user:pass@host:27017/db?authSource=admin&directConnection=true',
),
);
expect(result.isDirectConnection).toBe(true);
});
it('should default isDirectConnection to false when not specified in URI', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse('mongodb://user:pass@host:27017/db'),
);
expect(result.isDirectConnection).toBe(false);
});
it('should parse isDirectConnection=true from key-value format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'host=localhost port=27017 database=mydb user=admin password=secret directConnection=true',
),
);
expect(result.isDirectConnection).toBe(true);
});
it('should default isDirectConnection to false in key-value format when not specified', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'host=localhost port=27017 database=mydb user=admin password=secret',
),
);
expect(result.isDirectConnection).toBe(false);
});
});
describe('Password Placeholder Handling', () => {
it('should treat <db_password> placeholder as empty password in URI format', () => {
const result = expectSuccess(

View File

@@ -7,6 +7,7 @@ export type ParseResult = {
authDatabase: string;
useTls: boolean;
isSrv: boolean;
isDirectConnection: boolean;
};
export type ParseError = {
@@ -69,6 +70,7 @@ export class MongodbConnectionStringParser {
const database = decodeURIComponent(url.pathname.slice(1));
const authDatabase = this.getAuthSource(url.search) || 'admin';
const useTls = isSrv ? true : this.checkTlsMode(url.search);
const isDirectConnection = this.checkDirectConnection(url.search);
if (!host) {
return { error: 'Host is missing from connection string' };
@@ -87,6 +89,7 @@ export class MongodbConnectionStringParser {
authDatabase,
useTls,
isSrv,
isDirectConnection,
};
} catch (e) {
return {
@@ -133,6 +136,7 @@ export class MongodbConnectionStringParser {
}
const useTls = this.isTlsEnabled(tls);
const isDirectConnection = params['directConnection'] === 'true';
return {
host,
@@ -143,6 +147,7 @@ export class MongodbConnectionStringParser {
authDatabase,
useTls,
isSrv: false,
isDirectConnection,
};
} catch (e) {
return {
@@ -162,6 +167,16 @@ export class MongodbConnectionStringParser {
return params.get('authSource') || params.get('authDatabase') || undefined;
}
private static checkDirectConnection(queryString: string | undefined | null): boolean {
if (!queryString) return false;
const params = new URLSearchParams(
queryString.startsWith('?') ? queryString.slice(1) : queryString,
);
return params.get('directConnection') === 'true';
}
private static checkTlsMode(queryString: string | undefined | null): boolean {
if (!queryString) return false;

View File

@@ -11,5 +11,6 @@ export interface MongodbDatabase {
authDatabase: string;
isHttps: boolean;
isSrv: boolean;
isDirectConnection: boolean;
cpuCount: number;
}

View File

@@ -31,10 +31,18 @@ const notifyAuthListeners = () => {
};
export const userApi = {
async signUp(signUpRequest: SignUpRequest) {
async signUp(signUpRequest: SignUpRequest): Promise<SignInResponse> {
const requestOptions: RequestOptions = new RequestOptions();
requestOptions.setBody(JSON.stringify(signUpRequest));
return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/users/signup`, requestOptions);
return apiHelper
.fetchPostJson(`${getApplicationServer()}/api/v1/users/signup`, requestOptions)
.then((response: unknown): SignInResponse => {
const typedResponse = response as SignInResponse;
saveAuthorizedData(typedResponse.token, typedResponse.userId);
notifyAuthListeners();
return typedResponse;
});
},
async signIn(signInRequest: SignInRequest): Promise<SignInResponse> {

View File

@@ -1,3 +1,4 @@
export interface SendResetPasswordCodeRequest {
email: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -1,4 +1,5 @@
export interface SignInRequest {
email: string;
password: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -2,4 +2,5 @@ export interface SignUpRequest {
email: string;
password: string;
name: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -56,8 +56,8 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
const [showingRestoresBackupId, setShowingRestoresBackupId] = useState<string | undefined>();
const isReloadInProgress = useRef(false);
const isLazyLoadInProgress = useRef(false);
const lastRequestTimeRef = useRef<number>(0);
const isBackupsRequestInFlightRef = useRef(false);
const [downloadingBackupId, setDownloadingBackupId] = useState<string | undefined>();
const [cancellingBackupId, setCancellingBackupId] = useState<string | undefined>();
@@ -73,85 +73,59 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
};
const loadBackups = async (limit?: number) => {
if (isReloadInProgress.current || isLazyLoadInProgress.current) {
return;
}
if (isBackupsRequestInFlightRef.current) return;
isBackupsRequestInFlightRef.current = true;
isReloadInProgress.current = true;
const requestTime = Date.now();
lastRequestTimeRef.current = requestTime;
const loadLimit = limit ?? currentLimit;
try {
const loadLimit = limit || currentLimit;
const response = await backupsApi.getBackups(database.id, loadLimit, 0);
if (lastRequestTimeRef.current !== requestTime) return;
setBackups(response.backups);
setTotalBackups(response.total);
setHasMore(response.backups.length < response.total);
} catch (e) {
alert((e as Error).message);
if (lastRequestTimeRef.current === requestTime) {
alert((e as Error).message);
}
} finally {
isBackupsRequestInFlightRef.current = false;
}
isReloadInProgress.current = false;
};
const reloadInProgressBackups = async () => {
if (isReloadInProgress.current || isLazyLoadInProgress.current) {
return;
}
isReloadInProgress.current = true;
try {
// Fetch only the recent backups that could be in progress
// We fetch a small number (20) to capture recent backups that might be in progress
const response = await backupsApi.getBackups(database.id, 20, 0);
// Update only the backups that exist in both lists
setBackups((prevBackups) => {
const updatedBackups = [...prevBackups];
response.backups.forEach((newBackup) => {
const index = updatedBackups.findIndex((b) => b.id === newBackup.id);
if (index !== -1) {
updatedBackups[index] = newBackup;
} else if (index === -1 && updatedBackups.length < currentLimit) {
// New backup that doesn't exist yet (e.g., just created)
updatedBackups.unshift(newBackup);
}
});
return updatedBackups;
});
setTotalBackups(response.total);
} catch (e) {
alert((e as Error).message);
}
isReloadInProgress.current = false;
};
const loadMoreBackups = async () => {
if (isLoadingMore || !hasMore || isLazyLoadInProgress.current) {
if (isLoadingMore || !hasMore) {
return;
}
isLazyLoadInProgress.current = true;
setIsLoadingMore(true);
const newLimit = currentLimit + BACKUPS_PAGE_SIZE;
setCurrentLimit(newLimit);
const requestTime = Date.now();
lastRequestTimeRef.current = requestTime;
try {
const newLimit = currentLimit + BACKUPS_PAGE_SIZE;
const response = await backupsApi.getBackups(database.id, newLimit, 0);
if (lastRequestTimeRef.current !== requestTime) return;
setBackups(response.backups);
setCurrentLimit(newLimit);
setTotalBackups(response.total);
setHasMore(response.backups.length < response.total);
} catch (e) {
alert((e as Error).message);
if (lastRequestTimeRef.current === requestTime) {
alert((e as Error).message);
}
}
setIsLoadingMore(false);
isLazyLoadInProgress.current = false;
};
const makeBackup = async () => {
@@ -196,7 +170,7 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
try {
await backupsApi.cancelBackup(backupId);
await reloadInProgressBackups();
await loadBackups();
} catch (e) {
alert((e as Error).message);
}
@@ -220,22 +194,13 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
return () => {};
}, [database]);
// Reload backups that are in progress to update their state
useEffect(() => {
const hasInProgressBackups = backups.some(
(backup) => backup.status === BackupStatus.IN_PROGRESS,
);
if (!hasInProgressBackups) {
return;
}
const timeoutId = setTimeout(async () => {
await reloadInProgressBackups();
const intervalId = setInterval(() => {
loadBackups();
}, 1_000);
return () => clearTimeout(timeoutId);
}, [backups]);
return () => clearInterval(intervalId);
}, [currentLimit]);
useEffect(() => {
if (downloadingBackupId) {

View File

@@ -20,6 +20,7 @@ import {
type BackupConfig,
BackupEncryption,
type DatabasePlan,
RetentionPolicyType,
backupConfigApi,
} from '../../../entity/backups';
import { BackupNotificationType } from '../../../entity/backups/model/BackupNotificationType';
@@ -64,6 +65,15 @@ const weekdayOptions = [
{ value: 7, label: 'Sun' },
];
const retentionPolicyOptions = [
{
label: 'GFS (keep last N hourly, daily, weekly, monthly and yearly backups)',
value: RetentionPolicyType.GFS,
},
{ label: 'Time period (last N days)', value: RetentionPolicyType.TimePeriod },
{ label: 'Count (N last backups)', value: RetentionPolicyType.Count },
];
export const EditBackupConfigComponent = ({
user,
database,
@@ -95,6 +105,7 @@ export const EditBackupConfigComponent = ({
(backupConfig?.maxBackupSizeMb ?? 0) > 0 ||
(backupConfig?.maxBackupsTotalSizeMb ?? 0) > 0;
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const [isShowGfsHint, setShowGfsHint] = useState(false);
const timeFormat = useMemo(() => {
const is12 = getIs12Hour();
@@ -242,8 +253,20 @@ export const EditBackupConfigComponent = ({
timeOfDay: '00:00',
},
storage: undefined,
storePeriod:
plan.maxStoragePeriod === Period.FOREVER ? Period.THREE_MONTH : plan.maxStoragePeriod,
retentionPolicyType: IS_CLOUD
? RetentionPolicyType.GFS
: RetentionPolicyType.TimePeriod,
retentionTimePeriod: IS_CLOUD
? plan.maxStoragePeriod === Period.FOREVER
? Period.THREE_MONTH
: plan.maxStoragePeriod
: Period.THREE_MONTH,
retentionCount: 100,
retentionGfsHours: 24,
retentionGfsDays: 7,
retentionGfsWeeks: 4,
retentionGfsMonths: 12,
retentionGfsYears: 3,
sendNotificationsOn: [BackupNotificationType.BackupFailed],
isRetryIfFailed: true,
maxFailedTriesCount: 3,
@@ -295,10 +318,27 @@ export const EditBackupConfigComponent = ({
? getLocalDayOfMonth(backupInterval.dayOfMonth, backupInterval.timeOfDay)
: backupInterval?.dayOfMonth;
// mandatory-field check
const retentionPolicyType = backupConfig.retentionPolicyType ?? RetentionPolicyType.TimePeriod;
const isRetentionValid = (() => {
switch (retentionPolicyType) {
case RetentionPolicyType.TimePeriod:
return Boolean(backupConfig.retentionTimePeriod);
case RetentionPolicyType.Count:
return (backupConfig.retentionCount ?? 0) > 0;
case RetentionPolicyType.GFS:
return (
(backupConfig.retentionGfsDays ?? 0) > 0 ||
(backupConfig.retentionGfsWeeks ?? 0) > 0 ||
(backupConfig.retentionGfsMonths ?? 0) > 0 ||
(backupConfig.retentionGfsYears ?? 0) > 0
);
}
})();
const isAllFieldsFilled =
!backupConfig.isBackupsEnabled ||
(Boolean(backupConfig.storePeriod) &&
(isRetentionValid &&
Boolean(backupConfig.storage?.id) &&
Boolean(backupConfig.encryption) &&
Boolean(backupInterval?.interval) &&
@@ -467,7 +507,7 @@ export const EditBackupConfigComponent = ({
</>
)}
<div className="mt-2 mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mt-5 mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Storage</div>
<div className="flex w-full items-center">
<Select
@@ -530,23 +570,160 @@ export const EditBackupConfigComponent = ({
</div>
)}
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[150px] sm:mb-0">Store period</div>
<div className="flex items-center">
<div className="mt-5 mb-1 flex w-full flex-col items-start sm:flex-row sm:items-start">
<div className="mt-1 mb-1 min-w-[150px] sm:mb-0">Retention policy</div>
<div className="flex flex-col gap-1">
<Select
value={backupConfig.storePeriod}
onChange={(v) => updateBackupConfig({ storePeriod: v })}
value={retentionPolicyType}
options={retentionPolicyOptions}
size="small"
className="w-[200px]"
options={availablePeriods}
popupMatchSelectWidth={false}
onChange={(v) => {
const type = v as RetentionPolicyType;
const updates: Partial<typeof backupConfig> = { retentionPolicyType: type };
if (type === RetentionPolicyType.GFS) {
updates.retentionGfsHours = 24;
updates.retentionGfsDays = 7;
updates.retentionGfsWeeks = 4;
updates.retentionGfsMonths = 12;
updates.retentionGfsYears = 3;
} else if (type === RetentionPolicyType.Count) {
updates.retentionCount = 100;
}
updateBackupConfig(updates);
}}
/>
<Tooltip
className="cursor-pointer"
title="How long to keep the backups? Make sure you have enough storage space."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
{retentionPolicyType === RetentionPolicyType.TimePeriod && (
<div className="flex items-center">
<Select
value={backupConfig.retentionTimePeriod}
onChange={(v) => updateBackupConfig({ retentionTimePeriod: v })}
size="small"
className="w-[200px]"
options={availablePeriods}
/>
<Tooltip
className="cursor-pointer"
title="How long to keep the backups. Backups older than this period are automatically deleted."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
)}
{retentionPolicyType === RetentionPolicyType.Count && (
<div className="flex items-center">
<InputNumber
min={1}
value={backupConfig.retentionCount}
onChange={(v) => updateBackupConfig({ retentionCount: v ?? 1 })}
size="small"
className="w-[80px]"
/>
<span className="ml-2 text-sm text-gray-600 dark:text-gray-400">
most recent backups
</span>
<Tooltip
className="cursor-pointer"
title="Keep only the specified number of most recent backups. Older backups beyond this count are automatically deleted."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
)}
{retentionPolicyType === RetentionPolicyType.GFS && (
<>
<div>
<span
className="cursor-pointer text-xs text-blue-600 hover:text-blue-800"
onClick={() => setShowGfsHint(!isShowGfsHint)}
>
{isShowGfsHint ? 'Hide' : 'What is GFS (Grandfather-Father-Son)?'}
</span>
{isShowGfsHint && (
<div className="mt-1 max-w-[280px] text-xs text-gray-600 dark:text-gray-400">
GFS (Grandfather-Father-Son) rotation: keep the last N hourly, daily, weekly,
monthly and yearly backups. This allows keeping backups over long periods of
time within a reasonable storage space.
</div>
)}
</div>
<div className="flex flex-col gap-1">
<div className="flex items-center gap-2">
<span className="w-[110px] text-sm text-gray-600 dark:text-gray-400">
Hourly backups
</span>
<InputNumber
min={0}
value={backupConfig.retentionGfsHours}
onChange={(v) => updateBackupConfig({ retentionGfsHours: v ?? 0 })}
size="small"
className="w-[80px]"
/>
</div>
<div className="flex items-center gap-2">
<span className="w-[110px] text-sm text-gray-600 dark:text-gray-400">
Daily backups
</span>
<InputNumber
min={0}
value={backupConfig.retentionGfsDays}
onChange={(v) => updateBackupConfig({ retentionGfsDays: v ?? 0 })}
size="small"
className="w-[80px]"
/>
</div>
<div className="flex items-center gap-2">
<span className="w-[110px] text-sm text-gray-600 dark:text-gray-400">
Weekly backups
</span>
<InputNumber
min={0}
value={backupConfig.retentionGfsWeeks}
onChange={(v) => updateBackupConfig({ retentionGfsWeeks: v ?? 0 })}
size="small"
className="w-[80px]"
/>
</div>
<div className="flex items-center gap-2">
<span className="w-[110px] text-sm text-gray-600 dark:text-gray-400">
Monthly backups
</span>
<InputNumber
min={0}
value={backupConfig.retentionGfsMonths}
onChange={(v) => updateBackupConfig({ retentionGfsMonths: v ?? 0 })}
size="small"
className="w-[80px]"
/>
</div>
<div className="flex items-center gap-2">
<span className="w-[110px] text-sm text-gray-600 dark:text-gray-400">
Yearly backups
</span>
<InputNumber
min={0}
value={backupConfig.retentionGfsYears}
onChange={(v) => updateBackupConfig({ retentionGfsYears: v ?? 0 })}
size="small"
className="w-[80px]"
/>
</div>
</div>
</>
)}
</div>
</div>

View File

@@ -6,7 +6,12 @@ import { useMemo } from 'react';
import { useEffect, useState } from 'react';
import { IS_CLOUD } from '../../../constants';
import { type BackupConfig, BackupEncryption, backupConfigApi } from '../../../entity/backups';
import {
type BackupConfig,
BackupEncryption,
RetentionPolicyType,
backupConfigApi,
} from '../../../entity/backups';
import { BackupNotificationType } from '../../../entity/backups/model/BackupNotificationType';
import type { Database } from '../../../entity/databases';
import { Period } from '../../../entity/databases/model/Period';
@@ -60,10 +65,21 @@ const notificationLabels = {
[BackupNotificationType.BackupSuccess]: 'Backup success',
};
const formatGfsRetention = (config: BackupConfig): string => {
const parts: string[] = [];
if (config.retentionGfsHours > 0) parts.push(`${config.retentionGfsHours} hourly`);
if (config.retentionGfsDays > 0) parts.push(`${config.retentionGfsDays} daily`);
if (config.retentionGfsWeeks > 0) parts.push(`${config.retentionGfsWeeks} weekly`);
if (config.retentionGfsMonths > 0) parts.push(`${config.retentionGfsMonths} monthly`);
if (config.retentionGfsYears > 0) parts.push(`${config.retentionGfsYears} yearly`);
return parts.length > 0 ? parts.join(', ') : 'Not configured';
};
export const ShowBackupConfigComponent = ({ database }: Props) => {
const [backupConfig, setBackupConfig] = useState<BackupConfig>();
// Detect user's preferred time format (12-hour vs 24-hour)
const timeFormat = useMemo(() => {
const is12Hour = getIs12Hour();
return {
@@ -92,7 +108,6 @@ export const ShowBackupConfigComponent = ({ database }: Props) => {
const formattedTime = localTime ? localTime.format(timeFormat.format) : '';
// Convert UTC weekday/day-of-month to local equivalents for display
const displayedWeekday: number | undefined =
backupInterval?.interval === IntervalType.WEEKLY &&
backupInterval.weekday &&
@@ -107,6 +122,8 @@ export const ShowBackupConfigComponent = ({ database }: Props) => {
? getLocalDayOfMonth(backupInterval.dayOfMonth, backupInterval.timeOfDay)
: backupInterval?.dayOfMonth;
const retentionPolicyType = backupConfig.retentionPolicyType ?? RetentionPolicyType.TimePeriod;
return (
<div>
<div className="mb-1 flex w-full items-center">
@@ -193,8 +210,27 @@ export const ShowBackupConfigComponent = ({ database }: Props) => {
)}
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Store period</div>
<div>{backupConfig.storePeriod ? periodLabels[backupConfig.storePeriod] : ''}</div>
<div className="min-w-[150px]">Retention policy</div>
<div className="flex items-center gap-1">
{retentionPolicyType === RetentionPolicyType.TimePeriod && (
<span>
{backupConfig.retentionTimePeriod
? periodLabels[backupConfig.retentionTimePeriod]
: ''}
</span>
)}
{retentionPolicyType === RetentionPolicyType.Count && (
<span>Keep last {backupConfig.retentionCount} backups</span>
)}
{retentionPolicyType === RetentionPolicyType.GFS && (
<span className="flex items-center gap-1">
{formatGfsRetention(backupConfig)}
<Tooltip title="Grandfather-Father-Son rotation: keep the last N hourly, daily, weekly, monthly and yearly backups.">
<InfoCircleOutlined style={{ color: 'gray' }} />
</Tooltip>
</span>
)}
</div>
</div>
<div className="mb-1 flex w-full items-center">

View File

@@ -46,7 +46,10 @@ export const EditMongoDbSpecificDataComponent = ({
const [isTestingConnection, setIsTestingConnection] = useState(false);
const [isConnectionFailed, setIsConnectionFailed] = useState(false);
const hasAdvancedValues = !!database.mongodb?.authDatabase || !!database.mongodb?.isSrv;
const hasAdvancedValues =
!!database.mongodb?.authDatabase ||
!!database.mongodb?.isSrv ||
!!database.mongodb?.isDirectConnection;
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const parseFromClipboard = async () => {
@@ -80,11 +83,12 @@ export const EditMongoDbSpecificDataComponent = ({
authDatabase: result.authDatabase,
isHttps: result.useTls,
isSrv: result.isSrv,
isDirectConnection: result.isDirectConnection,
cpuCount: 1,
},
};
if (result.isSrv) {
if (result.isSrv || result.isDirectConnection) {
setShowAdvanced(true);
}
@@ -407,6 +411,31 @@ export const EditMongoDbSpecificDataComponent = ({
</div>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Direct connection</div>
<div className="flex items-center">
<Switch
checked={editingDatabase.mongodb?.isDirectConnection || false}
onChange={(checked) => {
if (!editingDatabase.mongodb) return;
setEditingDatabase({
...editingDatabase,
mongodb: { ...editingDatabase.mongodb, isDirectConnection: checked },
});
setIsConnectionTested(false);
}}
size="small"
/>
<Tooltip
className="cursor-pointer"
title="Connect directly to a single server, skipping replica set discovery. Useful when the server is behind a load balancer, proxy or tunnel."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Auth database</div>
<Input

View File

@@ -42,6 +42,13 @@ export const ShowMongoDbSpecificDataComponent = ({ database }: Props) => {
<div>{database.mongodb?.cpuCount}</div>
</div>
{database.mongodb?.isDirectConnection && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Direct connection</div>
<div>Yes</div>
</div>
)}
{database.mongodb?.authDatabase && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Auth database</div>

View File

@@ -10,6 +10,9 @@ interface Props {
}
export function EditNASStorageComponent({ storage, setStorage, setUnsaved }: Props) {
const shareHasSlash =
storage?.nasStorage?.share?.includes('/') || storage?.nasStorage?.share?.includes('\\');
return (
<>
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
@@ -60,24 +63,33 @@ export function EditNASStorageComponent({ storage, setStorage, setUnsaved }: Pro
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="mb-1 min-w-[110px] sm:mb-0">Share</div>
<Input
value={storage?.nasStorage?.share || ''}
onChange={(e) => {
if (!storage?.nasStorage) return;
<div className="flex flex-col">
<Input
value={storage?.nasStorage?.share || ''}
onChange={(e) => {
if (!storage?.nasStorage) return;
setStorage({
...storage,
nasStorage: {
...storage.nasStorage,
share: e.target.value.trim(),
},
});
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
placeholder="shared_folder"
/>
setStorage({
...storage,
nasStorage: {
...storage.nasStorage,
share: e.target.value.trim(),
},
});
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
placeholder="shared_folder"
status={shareHasSlash ? 'warning' : undefined}
/>
{shareHasSlash && (
<div className="mt-1 max-w-[250px] text-xs text-yellow-600">
Share must be a single share name. Use the Path field for subdirectories (e.g. Share:
Databasus, Path: DB1)
</div>
)}
</div>
</div>
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">

View File

@@ -1,9 +1,12 @@
import { Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
interface RequestResetPasswordComponentProps {
onSwitchToSignIn?: () => void;
@@ -20,6 +23,8 @@ export function RequestResetPasswordComponent({
const [error, setError] = useState('');
const [successMessage, setSuccessMessage] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateEmail = (): boolean => {
if (!email) {
setEmailError(true);
@@ -42,7 +47,10 @@ export function RequestResetPasswordComponent({
setLoading(true);
try {
const response = await userApi.sendResetPasswordCode({ email });
const response = await userApi.sendResetPasswordCode({
email,
cloudflareTurnstileToken: token,
});
setSuccessMessage(response.message);
// After successful code send, switch to reset password form
@@ -53,6 +61,7 @@ export function RequestResetPasswordComponent({
}, 2000);
} catch (e) {
setError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
setLoading(false);
@@ -84,6 +93,8 @@ export function RequestResetPasswordComponent({
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
import { Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID, IS_EMAIL_CONFIGURED } from '../../../constants';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
@@ -29,6 +32,8 @@ export function SignInComponent({
const [signInError, setSignInError] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateFieldsForSignIn = (): boolean => {
if (!email) {
setEmailError(true);
@@ -59,9 +64,11 @@ export function SignInComponent({
await userApi.signIn({
email,
password,
cloudflareTurnstileToken: token,
});
} catch (e) {
setSignInError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
setLoading(false);
@@ -119,6 +126,8 @@ export function SignInComponent({
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
import { App, Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
@@ -31,6 +34,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
const [signUpError, setSignUpError] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateFieldsForSignUp = (): boolean => {
if (!name || name.trim() === '') {
setNameError(true);
@@ -85,10 +90,11 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
email,
password,
name,
cloudflareTurnstileToken: token,
});
await userApi.signIn({ email, password });
} catch (e) {
setSignUpError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
}
@@ -173,6 +179,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -1,7 +1,6 @@
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
import utc from 'dayjs/plugin/utc';
import { StrictMode } from 'react';
import { createRoot } from 'react-dom/client';
import './index.css';
@@ -11,8 +10,4 @@ import App from './App.tsx';
dayjs.extend(utc);
dayjs.extend(relativeTime);
createRoot(document.getElementById('root')!).render(
<StrictMode>
<App />
</StrictMode>,
);
createRoot(document.getElementById('root')!).render(<App />);

View File

@@ -0,0 +1,116 @@
import { useEffect, useRef, useState } from 'react';
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
declare global {
interface Window {
turnstile?: {
render: (
container: string | HTMLElement,
options: {
sitekey: string;
callback: (token: string) => void;
'error-callback'?: () => void;
'expired-callback'?: () => void;
theme?: 'light' | 'dark' | 'auto';
size?: 'normal' | 'compact' | 'flexible';
appearance?: 'always' | 'execute' | 'interaction-only';
},
) => string;
reset: (widgetId: string) => void;
remove: (widgetId: string) => void;
getResponse: (widgetId: string) => string | undefined;
};
}
}
interface UseCloudflareTurnstileReturn {
containerRef: React.RefObject<HTMLDivElement | null>;
token: string | undefined;
resetCloudflareTurnstile: () => void;
}
const loadCloudflareTurnstileScript = (): Promise<void> => {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
return Promise.resolve();
}
return new Promise((resolve, reject) => {
if (document.querySelector('script[src*="turnstile"]')) {
resolve();
return;
}
const script = document.createElement('script');
script.src = 'https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit';
script.async = true;
script.defer = true;
script.onload = () => resolve();
script.onerror = () => reject(new Error('Failed to load Cloudflare Turnstile'));
document.head.appendChild(script);
});
};
export function useCloudflareTurnstile(): UseCloudflareTurnstileReturn {
const [token, setToken] = useState<string | undefined>(undefined);
const containerRef = useRef<HTMLDivElement | null>(null);
const widgetIdRef = useRef<string | null>(null);
useEffect(() => {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY || !containerRef.current) {
return;
}
loadCloudflareTurnstileScript()
.then(() => {
if (!window.turnstile || !containerRef.current) {
return;
}
try {
const widgetId = window.turnstile.render(containerRef.current, {
sitekey: CLOUDFLARE_TURNSTILE_SITE_KEY,
callback: (receivedToken: string) => {
setToken(receivedToken);
},
'error-callback': () => {
setToken(undefined);
},
'expired-callback': () => {
setToken(undefined);
},
theme: 'auto',
size: 'normal',
appearance: 'execute',
});
widgetIdRef.current = widgetId;
} catch (error) {
console.error('Failed to render Cloudflare Turnstile widget:', error);
}
})
.catch((error) => {
console.error('Failed to load Cloudflare Turnstile:', error);
});
return () => {
if (widgetIdRef.current && window.turnstile) {
window.turnstile.remove(widgetIdRef.current);
widgetIdRef.current = null;
}
};
}, []);
const resetCloudflareTurnstile = () => {
if (widgetIdRef.current && window.turnstile) {
window.turnstile.reset(widgetIdRef.current);
setToken(undefined);
}
};
return {
containerRef,
token,
resetCloudflareTurnstile,
};
}

View File

@@ -0,0 +1,17 @@
import { type JSX } from 'react';
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
interface CloudflareTurnstileWidgetProps {
containerRef: React.RefObject<HTMLDivElement | null>;
}
export function CloudflareTurnstileWidget({
containerRef,
}: CloudflareTurnstileWidgetProps): JSX.Element | null {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
return null;
}
return <div ref={containerRef} className="mb-3" />;
}

View File

@@ -1,3 +1,4 @@
export { CloudflareTurnstileWidget } from './CloudflareTurnstileWidget';
export { ConfirmationComponent } from './ConfirmationComponent';
export { StarButtonComponent } from './StarButtonComponent';
export { ThemeToggleComponent } from './ThemeToggleComponent';