Compare commits

...

39 Commits

Author SHA1 Message Date
Rostislav Dugin
a91ee50e31 Merge pull request #221 from databasus/develop
Develop
2026-01-05 21:08:50 +03:00
Rostislav Dugin
7e5562b115 FEATURE (mysql): Add automatic detection of allowed privileges to backup proper DB items 2026-01-05 21:07:53 +03:00
Rostislav Dugin
3ef51c4d68 FEATURE (databases): Improve check for required permissions to backup, check for read-only user and extend DBs models tests 2026-01-05 21:07:53 +03:00
github-actions[bot]
e47e513460 Update CITATION.cff to v2.20.3 2026-01-04 21:28:24 +00:00
Rostislav Dugin
226a6c06e6 Merge pull request #216 from databasus/develop
FIX (readonly user): Improve complexity of readonly user passwords to…
2026-01-05 00:07:37 +03:00
Rostislav Dugin
615fd9d574 FIX (readonly user): Improve complexity of readonly user passwords to pass Google Cloud requirements 2026-01-05 00:06:24 +03:00
github-actions[bot]
e9fcf20cdf Update CITATION.cff to v2.20.2 2026-01-04 20:14:38 +00:00
Rostislav Dugin
7649f4acfd Merge pull request #214 from databasus/develop
FIX (databases): Add timeout for deletion in case of storage stuck
2026-01-04 22:54:13 +03:00
Rostislav Dugin
7e4c3bcc19 FIX (databases): Add timeout for deletion in case of storage stuck 2026-01-04 22:51:11 +03:00
Rostislav Dugin
f2aecc0427 Merge pull request #212 from databasus/develop
FIX (mariadb): Add events exclusion for MariaDB
2026-01-04 22:16:24 +03:00
Rostislav Dugin
3ce7da319f FIX (mariadb): Add events exclusion for MariaDB 2026-01-04 22:15:31 +03:00
github-actions[bot]
096098f660 Update CITATION.cff to v2.20.1 2026-01-04 18:13:01 +00:00
Rostislav Dugin
c3ba4a7c5a Merge pull request #209 from databasus/develop
FIX (backups): Escape password over connection check to allow whitesp…
2026-01-04 20:50:08 +03:00
Rostislav Dugin
52c0f53608 FIX (backups): Escape password over connection check to allow whitespaces 2026-01-04 20:49:22 +03:00
github-actions[bot]
a5095acad4 Update CITATION.cff to v2.20.0 2026-01-04 14:54:48 +00:00
Rostislav Dugin
a6d32b5c09 Merge pull request #208 from databasus/develop
FIX (tests): Use unique DB names for PostgreSQL parallel tests
2026-01-04 17:26:52 +03:00
Rostislav Dugin
722560e824 FIX (tests): Use unique DB names for PostgreSQL parallel tests 2026-01-04 17:24:49 +03:00
Rostislav Dugin
496ac6120c Merge pull request #207 from databasus/develop
Develop
2026-01-04 16:24:22 +03:00
Rostislav Dugin
756c6c87af FIX (password): Trim db password at the moment of save and test connection instead right on the moment of input 2026-01-04 16:20:37 +03:00
Rostislav Dugin
a23d05b735 FIX (backups): Allow to make manual backups when scheduled are disabled 2026-01-04 16:11:14 +03:00
Rostislav Dugin
33a8d302eb FEATURE (workspaces): Add transfer between databases, storages and notifiers 2026-01-04 15:59:21 +03:00
github-actions[bot]
25ed1ffd2a Update CITATION.cff to v2.19.2 2026-01-02 13:30:15 +00:00
Rostislav Dugin
67582325bb Merge pull request #204 from databasus/develop
FIX (restores): Restore via stream instead of downloading backup to l…
2026-01-02 16:09:21 +03:00
Rostislav Dugin
5a89558cf6 FIX (restores): Restore via stream instead of downloading backup to local storage 2026-01-02 16:06:46 +03:00
github-actions[bot]
0ec02430b7 Update CITATION.cff to v2.19.1 2026-01-02 11:43:51 +00:00
Rostislav Dugin
49115684a7 Merge pull request #203 from databasus/develop
FIX (backups): Revert directory update
2026-01-02 14:23:27 +03:00
Rostislav Dugin
58ae86ff7a FIX (backups): Revert directory update 2026-01-02 14:20:32 +03:00
github-actions[bot]
82939bb079 Update CITATION.cff to v2.19.0 2026-01-02 09:55:59 +00:00
Rostislav Dugin
1697bfbae8 Merge pull request #202 from databasus/develop
Develop
2026-01-02 12:34:58 +03:00
Rostislav Dugin
205cb1ec02 FEATURE (restores): Validate there is enough disk space on restore 2026-01-02 12:33:31 +03:00
Rostislav Dugin
b9668875ef FIX (mongodb): Fix MongoDB build for ARM 2026-01-02 12:21:02 +03:00
Rostislav Dugin
ca3f0281a3 FIX (temp folders): Improve temp folders cleanup over backups and restores 2026-01-02 12:09:43 +03:00
Rostislav Dugin
1b8d783d4e FIX (temp): Add NAS temp directory to .gitignore 2026-01-02 11:50:08 +03:00
Rostislav Dugin
75b0477874 FIX (temp): Remove temp directory for NAS 2026-01-02 11:49:26 +03:00
Rostislav Dugin
19533514c2 FEATURE (postgresql): Move to directory format to speed up parallel backups 2026-01-02 11:46:15 +03:00
github-actions[bot]
b3c3ef136f Update CITATION.cff to v2.18.6 2026-01-01 19:11:46 +00:00
Rostislav Dugin
4a2ada384e Merge pull request #196 from databasus/develop
FIX (assets): Add square logos
2026-01-01 21:51:30 +03:00
Rostislav Dugin
b4fc0cfb56 FIX (assets): Add square logos 2026-01-01 21:51:04 +03:00
github-actions[bot]
a8fca1943b Update CITATION.cff to v2.18.5 2025-12-30 15:37:44 +00:00
133 changed files with 7604 additions and 1451 deletions

View File

@@ -32,5 +32,5 @@ keywords:
- mongodb
- mariadb
license: Apache-2.0
version: 2.18.4
date-released: "2025-12-30"
version: 2.20.3
date-released: "2026-01-04"

View File

@@ -172,19 +172,23 @@ RUN if [ "$TARGETARCH" = "amd64" ]; then \
# ========= Install MongoDB Database Tools =========
# Note: MongoDB Database Tools are backward compatible - single version supports all server versions (4.0-8.0)
# Use dpkg with apt-get -f install to handle dependencies
# Note: For ARM64, we use Ubuntu 22.04 package as MongoDB doesn't provide Debian 12 ARM64 packages
RUN apt-get update && \
if [ "$TARGETARCH" = "amd64" ]; then \
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
elif [ "$TARGETARCH" = "arm64" ]; then \
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-aarch64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
wget -q https://fastdl.mongodb.org/tools/db/mongodb-database-tools-ubuntu2204-arm64-100.10.0.deb -O /tmp/mongodb-database-tools.deb; \
fi && \
dpkg -i /tmp/mongodb-database-tools.deb || true && \
apt-get install -f -y --no-install-recommends && \
rm /tmp/mongodb-database-tools.deb && \
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends && \
rm -f /tmp/mongodb-database-tools.deb && \
rm -rf /var/lib/apt/lists/* && \
ln -sf /usr/bin/mongodump /usr/local/mongodb-database-tools/bin/mongodump && \
ln -sf /usr/bin/mongorestore /usr/local/mongodb-database-tools/bin/mongorestore
mkdir -p /usr/local/mongodb-database-tools/bin && \
if [ -f /usr/bin/mongodump ]; then \
ln -sf /usr/bin/mongodump /usr/local/mongodb-database-tools/bin/mongodump; \
fi && \
if [ -f /usr/bin/mongorestore ]; then \
ln -sf /usr/bin/mongorestore /usr/local/mongodb-database-tools/bin/mongorestore; \
fi
# Create postgres user and set up directories
RUN useradd -m -s /bin/bash postgres || true && \

BIN
assets/logo-square.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

12
assets/logo-square.svg Normal file
View File

@@ -0,0 +1,12 @@
<svg width="128" height="128" viewBox="0 0 128 128" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_287_1020)">
<path d="M50.1522 115.189C50.1522 121.189 57.1564 121.193 59 118C60.1547 116 61 114 61 108C61 102 58.1044 96.9536 55.3194 91.5175C54.6026 90.1184 53.8323 88.6149 53.0128 86.9234C51.6073 84.0225 49.8868 81.3469 47.3885 79.2139C47.0053 78.8867 46.8935 78.0093 46.9624 77.422C47.2351 75.1036 47.5317 72.7876 47.8283 70.4718C48.3186 66.6436 48.8088 62.8156 49.1909 58.9766C49.459 56.2872 49.4542 53.5119 49.1156 50.8329C48.3833 45.0344 45.1292 40.7783 40.1351 37.9114C38.6818 37.0771 38.2533 36.1455 38.4347 34.5853C38.9402 30.2473 40.6551 26.3306 42.8342 22.6642C44.8356 19.297 47.1037 16.0858 49.3676 12.8804C49.6576 12.4699 49.9475 12.0594 50.2367 11.6488C50.6069 11.1231 51.5231 10.7245 52.1971 10.7075C60.4129 10.5017 68.6303 10.3648 76.8477 10.2636C77.4123 10.2563 78.1584 10.5196 78.5221 10.9246C83.6483 16.634 88.2284 22.712 90.9778 29.9784C91.1658 30.4758 91.3221 30.9869 91.4655 31.4997C92.4976 35.1683 92.4804 35.1803 89.5401 37.2499L89.4071 37.3436C83.8702 41.2433 81.8458 46.8198 82.0921 53.349C82.374 60.8552 84.0622 68.1313 85.9869 75.3539C86.3782 76.8218 86.6318 77.9073 85.2206 79.2609C82.3951 81.9698 81.2196 85.6872 80.6575 89.4687C80.0724 93.4081 79.599 97.3637 79.1254 101.32C78.8627 103.515 78.8497 105.368 78.318 107.904C76.2819 117.611 71 128 63 128H50.1522C45 128 41 123.189 41 115.189H50.1522Z" fill="#155DFC"/>
<path d="M46.2429 6.56033C43.3387 11.1 40.3642 15.4031 37.7614 19.9209C35.413 23.9964 33.8487 28.4226 33.0913 33.1211C32.0998 39.2728 33.694 44.7189 38.0765 48.9775C41.6846 52.4835 42.6153 56.4472 42.152 61.1675C41.1426 71.4587 39.1174 81.5401 36.2052 91.4522C36.1769 91.5477 36.0886 91.6255 35.8974 91.8977C34.1517 91.3525 32.3161 90.8446 30.5266 90.2095C5.53011 81.3376 -12.7225 64.953 -24.1842 41.0298C-25.175 38.9625 -26.079 36.8498 -26.9263 34.7202C-27.0875 34.3151 -26.9749 33.5294 -26.6785 33.2531C-17.1479 24.3723 -7.64007 15.4647 2.00468 6.70938C8.64568 0.681612 16.5812 -1.21558 25.2457 0.739942C31.9378 2.24992 38.5131 4.27834 45.1363 6.09048C45.5843 6.2128 45.9998 6.45502 46.2429 6.56033Z" fill="#155DFC"/>
<path d="M96.9586 89.3257C95.5888 84.7456 94.0796 80.4011 93.0111 75.9514C91.6065 70.0978 90.4683 64.1753 89.3739 58.2529C88.755 54.9056 89.3998 51.8176 91.89 49.2108C98.2669 42.5358 98.3933 34.7971 95.3312 26.7037C92.7471 19.8739 88.593 13.9904 83.7026 8.60904C83.1298 7.9788 82.5693 7.33641 81.918 6.60491C82.2874 6.40239 82.5709 6.18773 82.8909 6.07999C90.1281 3.64085 97.4495 1.54842 105.041 0.488845C112.781 -0.591795 119.379 1.81818 125.045 6.97592C130.017 11.5018 134.805 16.2327 139.812 20.7188C143.822 24.3115 148.013 27.7066 152.19 31.1073C152.945 31.7205 153.137 32.2154 152.913 33.1041C149.059 48.4591 141.312 61.4883 129.457 71.9877C120.113 80.2626 109.35 85.9785 96.9586 89.3265V89.3257Z" fill="#155DFC"/>
</g>
<defs>
<clipPath id="clip0_287_1020">
<rect width="128" height="128" rx="6" fill="white"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 3.0 KiB

17
assets/tools/README.md Normal file
View File

@@ -0,0 +1,17 @@
We keep binaries here to speed up CI/CD tasks and builds.
Docker image needs:
- PostgreSQL client tools (versions 12-18)
- MySQL client tools (versions 5.7, 8.0, 8.4, 9)
- MariaDB client tools (versions 10.6, 12.1)
- MongoDB Database Tools (latest)
For most tools, we need a couple of binaries for each version. However, downloading them on each run would pull a couple of GBs every time.
So, to speed things up, we keep only the required executables (like pg_dump, mysqldump, mariadb-dump, mongodump, etc.).
It takes:
- ~ 100MB for ARM
- ~ 100MB for x64
Instead of GBs. See the Dockerfile for usage details.

3
backend/.gitignore vendored
View File

@@ -16,4 +16,5 @@ databasus-backend.exe
ui/build/*
pgdata-for-restore/
temp/
cmd.exe
cmd.exe
temp/

View File

@@ -217,6 +217,7 @@ func setUpDependencies() {
audit_logs.SetupDependencies()
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {

View File

@@ -1,6 +1,7 @@
package audit_logs
import (
"errors"
"net/http"
user_models "databasus-backend/internal/features/users/models"
@@ -50,7 +51,7 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
if errors.Is(err, ErrOnlyAdminsCanViewGlobalLogs) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -99,7 +100,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
if errors.Is(err, ErrInsufficientPermissionsToViewLogs) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}

View File

@@ -0,0 +1,12 @@
package audit_logs
import "errors"
// Sentinel errors for audit-log authorization failures. The controller
// matches these with errors.Is to translate them into HTTP 403 responses.
var (
	// ErrOnlyAdminsCanViewGlobalLogs is returned when a non-admin user
	// requests the global (cross-user) audit log.
	ErrOnlyAdminsCanViewGlobalLogs = errors.New(
		"only administrators can view global audit logs",
	)

	// ErrInsufficientPermissionsToViewLogs is returned when a user requests
	// another user's audit logs without being an admin or that user.
	ErrInsufficientPermissionsToViewLogs = errors.New(
		"insufficient permissions to view user audit logs",
	)
)

View File

@@ -1,7 +1,6 @@
package audit_logs
import (
"errors"
"log/slog"
"time"
@@ -44,7 +43,7 @@ func (s *AuditLogService) GetGlobalAuditLogs(
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
return nil, ErrOnlyAdminsCanViewGlobalLogs
}
limit := request.Limit
@@ -79,7 +78,7 @@ func (s *AuditLogService) GetUserAuditLogs(
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
return nil, ErrInsufficientPermissionsToViewLogs
}
limit := request.Limit

View File

@@ -25,6 +25,20 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
@@ -54,24 +68,13 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2)
// cleanup
for _, backup := range backups {
err := backupRepository.DeleteByID(backup.ID)
assert.NoError(t, err)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
@@ -83,6 +86,20 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
@@ -118,18 +135,6 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no new backup created
// cleanup
for _, backup := range backups {
err := backupRepository.DeleteByID(backup.ID)
assert.NoError(t, err)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
@@ -141,6 +146,20 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
@@ -180,18 +199,6 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1) // Should still be 1 backup, no retry attempted
// cleanup
for _, backup := range backups {
err := backupRepository.DeleteByID(backup.ID)
assert.NoError(t, err)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
@@ -203,6 +210,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
@@ -236,24 +257,13 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Wait for backup to complete (runs in goroutine)
WaitForBackupCompletion(t, database.ID, 1, 10*time.Second)
// assertions
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 2) // Should have 2 backups, retry was attempted
// cleanup
for _, backup := range backups {
err := backupRepository.DeleteByID(backup.ID)
assert.NoError(t, err)
}
databases.RemoveTestDatabase(database)
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
@@ -265,6 +275,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
@@ -306,16 +330,60 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 3) // Should have 3 backups, not more than max
// cleanup
for _, backup := range backups {
err := backupRepository.DeleteByID(backup.ID)
assert.NoError(t, err)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackgroundBackupWhenBakupsDisabled_BackupSkipped(t *testing.T) {
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// add old backup that would trigger new backup if enabled
backupRepository.Save(&Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
CreatedAt: time.Now().UTC().Add(-24 * time.Hour),
})
GetBackupBackgroundService().runPendingBackups()
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
}

View File

@@ -0,0 +1,17 @@
package common
import backups_config "databasus-backend/internal/features/backups/config"
// BackupType identifies the on-disk layout of a stored backup.
type BackupType string

const (
	BackupTypeDefault   BackupType = "DEFAULT"   // For MySQL, MongoDB, PostgreSQL legacy (-Fc)
	BackupTypeDirectory BackupType = "DIRECTORY" // PostgreSQL directory type (-Fd)
)

// BackupMetadata carries per-backup attributes persisted alongside the
// backup payload.
type BackupMetadata struct {
	// EncryptionSalt and EncryptionIV are presumably nil for unencrypted
	// backups — NOTE(review): confirm at the write site.
	EncryptionSalt *string
	EncryptionIV   *string
	Encryption     backups_config.BackupEncryption
	// Type records the backup layout (default single-file vs PostgreSQL
	// directory format).
	Type BackupType
}

View File

@@ -182,7 +182,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
fileReader, dbType, err := c.backupService.GetBackupFile(user, id)
fileReader, backup, database, err := c.backupService.GetBackupFile(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -193,15 +193,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
}
}()
extension := ".dump"
if dbType == databases.DatabaseTypeMysql || dbType == databases.DatabaseTypeMariadb {
extension = ".sql.zst"
}
filename := c.generateBackupFilename(backup, database)
ctx.Header("Content-Type", "application/octet-stream")
ctx.Header(
"Content-Disposition",
fmt.Sprintf("attachment; filename=\"backup_%s%s\"", id.String(), extension),
fmt.Sprintf("attachment; filename=\"%s\"", filename),
)
_, err = io.Copy(ctx.Writer, fileReader)
@@ -214,3 +211,62 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
type MakeBackupRequest struct {
DatabaseID uuid.UUID `json:"database_id" binding:"required"`
}
// generateBackupFilename builds the download filename for a backup as
// "<sanitized-db-name>_backup_<YYYY-MM-DD_HH-mm-ss><extension>".
// The extension is derived from the database engine type.
func (c *BackupController) generateBackupFilename(
	backup *Backup,
	database *databases.Database,
) string {
	// Go reference-time layout producing YYYY-MM-DD_HH-mm-ss.
	createdAt := backup.CreatedAt.Format("2006-01-02_15-04-05")

	// Strip filename-hostile characters from the user-supplied DB name.
	dbName := sanitizeFilename(database.Name)

	ext := c.getBackupExtension(database.Type)

	return dbName + "_backup_" + createdAt + ext
}
// getBackupExtension returns the download file extension for a database
// engine type; unknown types fall back to the generic ".backup".
func (c *BackupController) getBackupExtension(
	dbType databases.DatabaseType,
) string {
	// Extension per engine. PostgreSQL uses ".dump" (custom format);
	// MySQL/MariaDB dumps are zstd-compressed SQL.
	extensions := map[databases.DatabaseType]string{
		databases.DatabaseTypeMysql:    ".sql.zst",
		databases.DatabaseTypeMariadb:  ".sql.zst",
		databases.DatabaseTypePostgres: ".dump",
		databases.DatabaseTypeMongodb:  ".archive",
	}

	if ext, ok := extensions[dbType]; ok {
		return ext
	}
	return ".backup"
}
// sanitizeFilename makes a database name safe for use in a filename:
// spaces become underscores, and characters that are invalid or risky in
// filenames (/ \ : * ? " < > |) become hyphens. All other runes pass
// through unchanged.
func sanitizeFilename(name string) string {
	sanitized := make([]rune, 0, len(name))
	for _, r := range name {
		switch r {
		case ' ':
			sanitized = append(sanitized, '_')
		case '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			sanitized = append(sanitized, '-')
		default:
			sanitized = append(sanitized, r)
		}
	}
	return string(sanitized)
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,6 +16,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -494,6 +496,112 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
assert.True(t, found, "Audit log for backup download not found")
}
func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
tests := []struct {
name string
databaseName string
expectedExt string
expectedInName string
}{
{
name: "PostgreSQL database",
databaseName: "my_postgres_db",
expectedExt: ".dump",
expectedInName: "my_postgres_db_backup_",
},
{
name: "Database name with spaces",
databaseName: "my test db",
expectedExt: ".dump",
expectedInName: "my_test_db_backup_",
},
{
name: "Database name with special characters",
databaseName: "my:db/test",
expectedExt: ".dump",
expectedInName: "my-db-test_backup_",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase(tt.databaseName, workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := createTestBackup(database, owner)
resp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
contentDisposition := resp.Headers.Get("Content-Disposition")
assert.NotEmpty(t, contentDisposition, "Content-Disposition header should be present")
// Verify the filename contains expected parts
assert.Contains(
t,
contentDisposition,
tt.expectedInName,
"Filename should contain sanitized database name",
)
assert.Contains(
t,
contentDisposition,
tt.expectedExt,
"Filename should have correct extension",
)
assert.Contains(t, contentDisposition, "attachment", "Should be an attachment")
// Verify timestamp format (YYYY-MM-DD_HH-mm-ss)
assert.Regexp(
t,
`\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}`,
contentDisposition,
"Filename should contain timestamp",
)
})
}
}
// Test_SanitizeFilename verifies sanitizeFilename's replacement table:
// spaces map to underscores, filename-hostile characters (/ \ : * ? " < > |)
// map to hyphens, and safe characters pass through unchanged.
func Test_SanitizeFilename(t *testing.T) {
	tests := []struct {
		input    string
		expected string
	}{
		{input: "simple_name", expected: "simple_name"},
		{input: "name with spaces", expected: "name_with_spaces"},
		{input: "name/with\\slashes", expected: "name-with-slashes"},
		{input: "name:with*special?chars", expected: "name-with-special-chars"},
		{input: "name<with>pipes|", expected: "name-with-pipes-"},
		{input: `name"with"quotes`, expected: "name-with-quotes"},
	}

	for _, tt := range tests {
		// Subtest named after the input so failures identify the case.
		t.Run(tt.input, func(t *testing.T) {
			result := sanitizeFilename(tt.input)
			assert.Equal(t, tt.expected, result)
		})
	}
}
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -573,7 +681,13 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -581,9 +695,9 @@ func createTestDatabase(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},

View File

@@ -3,7 +3,7 @@ package backups
import (
"context"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
usecases_common "databasus-backend/internal/features/backups/backups/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"

View File

@@ -214,11 +214,6 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
return
}
if !backupConfig.IsBackupsEnabled {
s.logger.Info("Backups are not enabled for this database")
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is not defined")
return
@@ -502,19 +497,19 @@ func (s *BackupService) CancelBackup(
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
) (io.ReadCloser, databases.DatabaseType, error) {
) (io.ReadCloser, *Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, "", err
return nil, nil, nil, err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, "", err
return nil, nil, nil, err
}
if database.WorkspaceID == nil {
return nil, "", errors.New("cannot download backup for database without workspace")
return nil, nil, nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
@@ -522,10 +517,12 @@ func (s *BackupService) GetBackupFile(
user,
)
if err != nil {
return nil, "", err
return nil, nil, nil, err
}
if !canAccess {
return nil, "", errors.New("insufficient permissions to download backup for this database")
return nil, nil, nil, errors.New(
"insufficient permissions to download backup for this database",
)
}
s.auditLogService.WriteAuditLog(
@@ -540,10 +537,10 @@ func (s *BackupService) GetBackupFile(
reader, err := s.getBackupReader(backupID)
if err != nil {
return nil, "", err
return nil, nil, nil, err
}
return reader, database.Type, nil
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *Backup) error {

View File

@@ -7,7 +7,7 @@ import (
"testing"
"time"
"databasus-backend/internal/features/backups/backups/usecases/common"
common "databasus-backend/internal/features/backups/backups/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
encryption_secrets "databasus-backend/internal/features/encryption/secrets"

View File

@@ -1,12 +1,16 @@
package backups
import (
"testing"
"time"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
func CreateTestRouter() *gin.Engine {
@@ -18,3 +22,49 @@ func CreateTestRouter() *gin.Engine {
GetBackupController(),
)
}
// WaitForBackupCompletion polls the backup repository until a new backup
// appears for the given database (i.e. the backup count exceeds
// expectedInitialCount) and that newest backup reaches a terminal status
// (completed, failed, or canceled), or until the timeout elapses.
//
// On timeout it only logs and returns; callers that require a completed
// backup must assert on the resulting backup state themselves.
func WaitForBackupCompletion(
	t *testing.T,
	databaseID uuid.UUID,
	expectedInitialCount int,
	timeout time.Duration,
) {
	// Attribute log output to the calling test, not this helper.
	t.Helper()

	const pollInterval = 50 * time.Millisecond

	deadline := time.Now().UTC().Add(timeout)
	for time.Now().UTC().Before(deadline) {
		backups, err := backupRepository.FindByDatabaseID(databaseID)
		if err != nil {
			t.Logf("WaitForBackupCompletion: error finding backups: %v", err)
			time.Sleep(pollInterval)
			continue
		}
		t.Logf(
			"WaitForBackupCompletion: found %d backups (expected > %d)",
			len(backups),
			expectedInitialCount,
		)
		if len(backups) > expectedInitialCount {
			// Assumes FindByDatabaseID returns newest-first — TODO confirm
			// the ordering against the repository implementation.
			newestBackup := backups[0]
			t.Logf("WaitForBackupCompletion: newest backup status: %s", newestBackup.Status)
			if newestBackup.Status == BackupStatusCompleted ||
				newestBackup.Status == BackupStatusFailed ||
				newestBackup.Status == BackupStatusCanceled {
				t.Logf(
					"WaitForBackupCompletion: backup finished with status %s",
					newestBackup.Status,
				)
				return
			}
		}
		time.Sleep(pollInterval)
	}
	t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}

View File

@@ -1,9 +0,0 @@
package common
import backups_config "databasus-backend/internal/features/backups/config"
// BackupMetadata describes how a stored backup artifact was produced so it
// can be handled correctly on restore/download.
type BackupMetadata struct {
	// EncryptionSalt is the salt used when the backup was encrypted;
	// presumably nil for unencrypted backups — confirm against the
	// encryption-setup code paths.
	EncryptionSalt *string
	// EncryptionIV is the initialization vector used when the backup was
	// encrypted; presumably nil for unencrypted backups.
	EncryptionIV *string
	// Encryption records which encryption mode was applied
	// (e.g. BackupEncryptionNone when no encryption was configured).
	Encryption backups_config.BackupEncryption
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
common "databasus-backend/internal/features/backups/backups/common"
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
@@ -30,7 +30,7 @@ func (uc *CreateBackupUsecase) Execute(
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
switch database.Type {
case databases.DatabaseTypePostgres:
return uc.CreatePostgresqlBackupUsecase.Execute(

View File

@@ -18,8 +18,8 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mariadbtypes "databasus-backend/internal/features/databases/databases/mariadb"
@@ -57,17 +57,13 @@ func (uc *CreateMariadbBackupUsecase) Execute(
db *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info(
"Creating MariaDB backup via mariadb-dump",
"databaseId", db.ID,
"storageId", storage.ID,
)
if !backupConfig.IsBackupsEnabled {
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
mdb := db.Mariadb
if mdb == nil {
return nil, fmt.Errorf("mariadb database configuration is required")
@@ -111,12 +107,17 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--user=" + mdb.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--quick",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, "--compress")
if mdb.IsHttps {
@@ -140,7 +141,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
mdbConfig *mariadbtypes.MariadbDatabase,
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info("Streaming MariaDB backup to storage", "mariadbBin", mariadbBin)
ctx, cancel := uc.createBackupContext(parentCtx)
@@ -196,7 +197,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
if err != nil {
return nil, fmt.Errorf("failed to create zstd writer: %w", err)
}
countingWriter := usecases_common.NewCountingWriter(zstdWriter)
countingWriter := common.NewCountingWriter(zstdWriter)
saveErrCh := make(chan error, 1)
go func() {
@@ -264,7 +265,7 @@ func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp("", "mycnf")
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -401,8 +402,8 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
metadata := usecases_common.BackupMetadata{}
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -15,8 +15,8 @@ import (
"github.com/google/uuid"
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mongodbtypes "databasus-backend/internal/features/databases/databases/mongodb"
@@ -51,17 +51,13 @@ func (uc *CreateMongodbBackupUsecase) Execute(
db *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info(
"Creating MongoDB backup via mongodump",
"databaseId", db.ID,
"storageId", storage.ID,
)
if !backupConfig.IsBackupsEnabled {
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
mdb := db.Mongodb
if mdb == nil {
return nil, fmt.Errorf("mongodb database configuration is required")
@@ -124,7 +120,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
args []string,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info("Streaming MongoDB backup to storage", "mongodumpBin", mongodumpBin)
ctx, cancel := uc.createBackupContext(parentCtx)
@@ -175,7 +171,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
return nil, err
}
countingWriter := usecases_common.NewCountingWriter(finalWriter)
countingWriter := common.NewCountingWriter(finalWriter)
saveErrCh := make(chan error, 1)
go func() {
@@ -264,8 +260,8 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
backupMetadata := usecases_common.BackupMetadata{
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
backupMetadata := common.BackupMetadata{
Encryption: backups_config.BackupEncryptionNone,
}

View File

@@ -18,8 +18,8 @@ import (
"github.com/klauspost/compress/zstd"
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
mysqltypes "databasus-backend/internal/features/databases/databases/mysql"
@@ -57,17 +57,13 @@ func (uc *CreateMysqlBackupUsecase) Execute(
db *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info(
"Creating MySQL backup via mysqldump",
"databaseId", db.ID,
"storageId", storage.ID,
)
if !backupConfig.IsBackupsEnabled {
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
my := db.Mysql
if my == nil {
return nil, fmt.Errorf("mysql database configuration is required")
@@ -109,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--user=" + my.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--set-gtid-purged=OFF",
"--quick",
"--verbose",
}
if my.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if my.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
if my.IsHttps {
@@ -155,7 +156,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
myConfig *mysqltypes.MysqlDatabase,
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info("Streaming MySQL backup to storage", "mysqlBin", mysqlBin)
ctx, cancel := uc.createBackupContext(parentCtx)
@@ -211,7 +212,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
if err != nil {
return nil, fmt.Errorf("failed to create zstd writer: %w", err)
}
countingWriter := usecases_common.NewCountingWriter(zstdWriter)
countingWriter := common.NewCountingWriter(zstdWriter)
saveErrCh := make(chan error, 1)
go func() {
@@ -279,7 +280,7 @@ func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp("", "mycnf")
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -414,8 +415,8 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
metadata := usecases_common.BackupMetadata{}
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -15,8 +15,8 @@ import (
"time"
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
usecases_common "databasus-backend/internal/features/backups/backups/usecases/common"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
pgtypes "databasus-backend/internal/features/databases/databases/postgresql"
@@ -60,7 +60,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info(
"Creating PostgreSQL backup via pg_dump custom format",
"databaseId",
@@ -69,10 +69,6 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
storage.ID,
)
if !backupConfig.IsBackupsEnabled {
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
pg := db.Postgresql
if pg == nil {
@@ -119,7 +115,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storage *storages.Storage,
db *databases.Database,
backupProgressListener func(completedMBs float64),
) (*usecases_common.BackupMetadata, error) {
) (*common.BackupMetadata, error) {
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
ctx, cancel := uc.createBackupContext(parentCtx)
@@ -171,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
return nil, err
}
countingWriter := usecases_common.NewCountingWriter(finalWriter)
countingWriter := common.NewCountingWriter(finalWriter)
// The backup ID becomes the object key / filename in storage
@@ -335,11 +331,6 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlD
"--verbose",
}
// Add parallel jobs based on CPU count
if pg.CpuCount > 1 {
args = append(args, "-j", strconv.Itoa(pg.CpuCount))
}
for _, schema := range pg.IncludeSchemas {
args = append(args, "-n", schema)
}
@@ -476,8 +467,8 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
metadata := usecases_common.BackupMetadata{}
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone

View File

@@ -1,9 +1,11 @@
package backups_config
import (
users_middleware "databasus-backend/internal/features/users/middleware"
"errors"
"net/http"
users_middleware "databasus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
@@ -16,6 +18,8 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
router.POST("/backup-configs/database/:id/transfer", c.TransferDatabase)
}
// SaveBackupConfig
@@ -120,3 +124,86 @@ func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"isUsing": isUsing})
}
// CountDatabasesForStorage
// @Summary Count databases using a storage
// @Description Get the count of databases that are using a specific storage
// @Tags backup-configs
// @Produce json
// @Param id path string true "Storage ID"
// @Success 200 {object} map[string]int
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backup-configs/storage/{id}/databases-count [get]
func (c *BackupConfigController) CountDatabasesForStorage(ctx *gin.Context) {
	// Reject unauthenticated callers before touching any path parameters.
	user, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	storageID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
		return
	}

	count, countErr := c.backupConfigService.CountDatabasesForStorage(user, storageID)
	if countErr != nil {
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": countErr.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"count": count})
}
// TransferDatabase
// @Summary Transfer database to another workspace
// @Description Transfer a database from one workspace to another. Can transfer to a new storage or transfer with the existing storage. Can also specify target notifiers from the target workspace.
// @Tags backup-configs
// @Accept json
// @Produce json
// @Param id path string true "Database ID"
// @Param request body TransferDatabaseRequest true "Transfer request with targetWorkspaceId, storage options (targetStorageId or isTransferWithStorage), and optional targetNotifierIds"
// @Success 200 {object} map[string]string "Database transferred successfully"
// @Failure 400 {object} map[string]string "Invalid request, target storage/notifier not in target workspace, or transfer failed"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 403 {object} map[string]string "Insufficient permissions"
// @Router /backup-configs/database/{id}/transfer [post]
func (c *BackupConfigController) TransferDatabase(ctx *gin.Context) {
	user, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	databaseID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
		return
	}

	var request TransferDatabaseRequest
	if bindErr := ctx.ShouldBindJSON(&request); bindErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}
	// The binding tag alone does not catch an explicit zero UUID, so check it here.
	if request.TargetWorkspaceID == uuid.Nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "targetWorkspaceId is required"})
		return
	}

	transferErr := c.backupConfigService.TransferDatabaseToWorkspace(user, databaseID, &request)
	if transferErr != nil {
		// Permission failures map to 403; every other failure is a client error.
		if errors.Is(transferErr, ErrInsufficientPermissionsInSourceWorkspace) ||
			errors.Is(transferErr, ErrInsufficientPermissionsInTargetWorkspace) {
			ctx.JSON(http.StatusForbidden, gin.H{"error": transferErr.Error()})
			return
		}
		ctx.JSON(http.StatusBadRequest, gin.H{"error": transferErr.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"message": "database transferred successfully"})
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,6 +2,7 @@ package backups_config
import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
)
@@ -11,6 +12,7 @@ var backupConfigService = &BackupConfigService{
backupConfigRepository,
databases.GetDatabaseService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
nil,
}
@@ -25,3 +27,7 @@ func GetBackupConfigController() *BackupConfigController {
func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
func SetupDependencies() {
storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
}

View File

@@ -0,0 +1,11 @@
package backups_config
import "github.com/google/uuid"
// TransferDatabaseRequest is the payload for
// POST /backup-configs/database/:id/transfer.
type TransferDatabaseRequest struct {
	// TargetWorkspaceID is the workspace the database is moved into (required).
	TargetWorkspaceID uuid.UUID `json:"targetWorkspaceId" binding:"required"`
	// TargetStorageID selects an existing storage in the target workspace;
	// only consulted when IsTransferWithStorage is false.
	TargetStorageID *uuid.UUID `json:"targetStorageId,omitempty"`
	// IsTransferWithStorage moves the database's current storage into the
	// target workspace along with the database.
	IsTransferWithStorage bool `json:"isTransferWithStorage,omitempty"`
	// IsTransferWithNotifiers moves the database's attached notifiers into
	// the target workspace along with the database.
	IsTransferWithNotifiers bool `json:"isTransferWithNotifiers,omitempty"`
	// TargetNotifierIDs optionally re-attaches the database to these
	// notifiers from the target workspace after the transfer.
	TargetNotifierIDs []uuid.UUID `json:"targetNotifierIds,omitempty"`
}

View File

@@ -0,0 +1,30 @@
package backups_config
import "errors"
// Sentinel errors for the database-transfer flow. The controller maps the
// two permission errors to HTTP 403; everything else surfaces as 400.
var (
	ErrInsufficientPermissionsInSourceWorkspace = errors.New(
		"insufficient permissions to manage database in source workspace",
	)
	ErrInsufficientPermissionsInTargetWorkspace = errors.New(
		"insufficient permissions to manage database in target workspace",
	)
	ErrTargetStorageNotInTargetWorkspace = errors.New(
		"target storage does not belong to target workspace",
	)
	ErrTargetNotifierNotInTargetWorkspace = errors.New(
		"target notifier does not belong to target workspace",
	)
	ErrStorageHasOtherAttachedDatabases = errors.New(
		"storage has other attached databases and cannot be transferred with this database",
	)
	ErrDatabaseHasNoStorage = errors.New(
		"database has no storage attached",
	)
	ErrDatabaseHasNoWorkspace = errors.New(
		"database has no workspace",
	)
	ErrTargetStorageNotSpecified = errors.New(
		"target storage is not specified",
	)
)

View File

@@ -0,0 +1,166 @@
package backups_config
import (
"fmt"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
)
// A notifier created in the same workspace as the database can be attached
// through the database update endpoint.
func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
	router := createTestRouterWithNotifier()

	member := users_testing.CreateTestUser(users_enums.UserRoleMember)
	ws := workspaces_testing.CreateTestWorkspace("Test Workspace", member, router)
	db := createTestDatabaseViaAPI("Test Database", ws.ID, member.Token, router)
	attached := notifiers.CreateTestNotifier(ws.ID)

	db.Notifiers = []notifiers.Notifier{*attached}

	var updated databases.Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/update",
		"Bearer "+member.Token,
		db,
		http.StatusOK,
		&updated,
	)

	assert.Equal(t, db.ID, updated.ID)
	assert.Len(t, updated.Notifiers, 1)
	assert.Equal(t, attached.ID, updated.Notifiers[0].ID)
}
// Attaching a notifier that belongs to another workspace is rejected with a
// 400 and an explanatory message.
func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
	router := createTestRouterWithNotifier()

	firstOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	firstWorkspace := workspaces_testing.CreateTestWorkspace("Workspace 1", firstOwner, router)
	db := createTestDatabaseViaAPI("Test Database", firstWorkspace.ID, firstOwner.Token, router)

	secondOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	secondWorkspace := workspaces_testing.CreateTestWorkspace("Workspace 2", secondOwner, router)
	foreignNotifier := notifiers.CreateTestNotifier(secondWorkspace.ID)

	db.Notifiers = []notifiers.Notifier{*foreignNotifier}

	resp := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/databases/update",
		"Bearer "+firstOwner.Token,
		db,
		http.StatusBadRequest,
	)

	assert.Contains(t, string(resp.Body), "notifier does not belong to this workspace")
}
// A notifier cannot be deleted while a database is still attached to it.
func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
	router := createTestRouterWithNotifier()

	member := users_testing.CreateTestUser(users_enums.UserRoleMember)
	ws := workspaces_testing.CreateTestWorkspace("Test Workspace", member, router)
	db := createTestDatabaseViaAPI("Test Database", ws.ID, member.Token, router)
	attached := notifiers.CreateTestNotifier(ws.ID)

	// Attach the notifier to the database first.
	db.Notifiers = []notifiers.Notifier{*attached}
	var updated databases.Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/update",
		"Bearer "+member.Token,
		db,
		http.StatusOK,
		&updated,
	)

	// Deleting the attached notifier must now fail.
	resp := test_utils.MakeDeleteRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/notifiers/%s", attached.ID.String()),
		"Bearer "+member.Token,
		http.StatusBadRequest,
	)

	assert.Contains(
		t,
		string(resp.Body),
		"notifier has attached databases and cannot be deleted",
	)
}
// A notifier cannot be transferred to another workspace while a database is
// still attached to it.
func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
	router := createTestRouterWithNotifier()

	member := users_testing.CreateTestUser(users_enums.UserRoleMember)
	sourceWorkspace := workspaces_testing.CreateTestWorkspace("Test Workspace", member, router)
	destinationWorkspace := workspaces_testing.CreateTestWorkspace("Target Workspace", member, router)
	db := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, member.Token, router)
	attached := notifiers.CreateTestNotifier(sourceWorkspace.ID)

	// Attach the notifier to the database first.
	db.Notifiers = []notifiers.Notifier{*attached}
	var updated databases.Database
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/databases/update",
		"Bearer "+member.Token,
		db,
		http.StatusOK,
		&updated,
	)

	// Transferring the attached notifier must now fail.
	payload := notifiers.TransferNotifierRequest{
		TargetWorkspaceID: destinationWorkspace.ID,
	}
	resp := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/notifiers/%s/transfer", attached.ID.String()),
		"Bearer "+member.Token,
		payload,
		http.StatusBadRequest,
	)

	assert.Contains(
		t,
		string(resp.Body),
		"notifier has attached databases and cannot be transferred",
	)
}
// createTestRouterWithNotifier builds a gin router wired with every
// controller the notifier/backup-config tests exercise, then runs each
// feature's dependency setup.
func createTestRouterWithNotifier() *gin.Engine {
	testRouter := workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		GetBackupConfigController(),
		storages.GetStorageController(),
		notifiers.GetNotifierController(),
	)

	// Wire cross-feature dependencies after all controllers exist.
	storages.SetupDependencies()
	databases.SetupDependencies()
	notifiers.SetupDependencies()
	SetupDependencies()

	return testRouter
}

View File

@@ -102,3 +102,19 @@ func (r *BackupConfigRepository) IsStorageUsing(storageID uuid.UUID) (bool, erro
return count > 0, nil
}
// GetDatabasesIDsByStorageID returns the IDs of every database whose backup
// config row references the given storage.
func (r *BackupConfigRepository) GetDatabasesIDsByStorageID(
	storageID uuid.UUID,
) ([]uuid.UUID, error) {
	var ids []uuid.UUID
	err := storage.
		GetDb().
		Table("backup_configs").
		Where("storage_id = ?", storageID).
		Pluck("database_id", &ids).Error
	if err != nil {
		return nil, err
	}
	return ids, nil
}

View File

@@ -5,6 +5,7 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -17,6 +18,7 @@ type BackupConfigService struct {
backupConfigRepository *BackupConfigRepository
databaseService *databases.DatabaseService
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
dbStorageChangeListener BackupConfigStorageChangeListener
@@ -28,6 +30,17 @@ func (s *BackupConfigService) SetDatabaseStorageChangeListener(
s.dbStorageChangeListener = dbStorageChangeListener
}
// GetStorageAttachedDatabasesIDs returns the IDs of all databases whose
// backup configuration references the given storage.
func (s *BackupConfigService) GetStorageAttachedDatabasesIDs(
	storageID uuid.UUID,
) ([]uuid.UUID, error) {
	// The repository call already returns exactly the shape we need, so
	// forward it directly instead of unpacking and re-wrapping the result.
	return s.backupConfigRepository.GetDatabasesIDsByStorageID(storageID)
}
func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
@@ -53,6 +66,16 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
return nil, errors.New("insufficient permissions to modify backup configuration")
}
if backupConfig.Storage != nil && backupConfig.Storage.ID != uuid.Nil {
storage, err := s.storageService.GetStorageByID(backupConfig.Storage.ID)
if err != nil {
return nil, err
}
if storage.WorkspaceID != *database.WorkspaceID {
return nil, errors.New("storage does not belong to the same workspace as the database")
}
}
return s.SaveBackupConfig(backupConfig)
}
@@ -129,6 +152,23 @@ func (s *BackupConfigService) IsStorageUsing(
return s.backupConfigRepository.IsStorageUsing(storageID)
}
// CountDatabasesForStorage returns how many databases have a backup config
// pointing at the given storage, after verifying the user can see the storage.
func (s *BackupConfigService) CountDatabasesForStorage(
	user *users_models.User,
	storageID uuid.UUID,
) (int, error) {
	// Access check: GetStorage fails when the storage does not exist or the
	// user is not allowed to read it.
	if _, err := s.storageService.GetStorage(user, storageID); err != nil {
		return 0, err
	}

	ids, err := s.backupConfigRepository.GetDatabasesIDsByStorageID(storageID)
	if err != nil {
		return 0, err
	}
	return len(ids), nil
}
func (s *BackupConfigService) GetBackupConfigsWithEnabledBackups() ([]*BackupConfig, error) {
return s.backupConfigRepository.GetWithEnabledBackups()
}
@@ -176,6 +216,157 @@ func (s *BackupConfigService) initializeDefaultConfig(
return err
}
// TransferDatabaseToWorkspace moves a database (and optionally its storage
// and notifiers) from its current workspace into request.TargetWorkspaceID.
//
// Steps, in order: permission checks in both workspaces, target-notifier
// validation, optional notifier transfer, storage handling (move the current
// storage, or re-point at an existing target-workspace storage), database
// transfer, and finally optional re-assignment of target notifiers.
//
// NOTE(review): this flow is not transactional — notifiers are transferred
// before the storage branch can still fail (e.g. ErrDatabaseHasNoStorage or
// ErrStorageHasOtherAttachedDatabases), which would leave the notifiers
// already moved while the database stays put. Confirm this partial-state
// behavior is acceptable.
func (s *BackupConfigService) TransferDatabaseToWorkspace(
	user *users_models.User,
	databaseID uuid.UUID,
	request *TransferDatabaseRequest,
) error {
	database, err := s.databaseService.GetDatabaseByID(databaseID)
	if err != nil {
		return err
	}
	if database.WorkspaceID == nil {
		return ErrDatabaseHasNoWorkspace
	}
	// The caller must be allowed to manage databases in BOTH workspaces.
	canManageSource, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
	if err != nil {
		return err
	}
	if !canManageSource {
		return ErrInsufficientPermissionsInSourceWorkspace
	}
	canManageTarget, err := s.workspaceService.CanUserManageDBs(request.TargetWorkspaceID, user)
	if err != nil {
		return err
	}
	if !canManageTarget {
		return ErrInsufficientPermissionsInTargetWorkspace
	}
	// Validate requested target notifiers up front, before any mutation.
	if err := s.validateTargetNotifiers(request); err != nil {
		return err
	}
	backupConfig, err := s.GetBackupConfigByDbId(databaseID)
	if err != nil {
		return err
	}
	// Best-effort: transferNotifiers discards individual transfer errors.
	if request.IsTransferWithNotifiers {
		s.transferNotifiers(user, database, request.TargetWorkspaceID)
	}
	if request.IsTransferWithStorage {
		if backupConfig.StorageID == nil {
			return ErrDatabaseHasNoStorage
		}
		// The storage may only move if no OTHER database uses it.
		attachedDatabasesIDs, err := s.GetStorageAttachedDatabasesIDs(*backupConfig.StorageID)
		if err != nil {
			return err
		}
		for _, dbID := range attachedDatabasesIDs {
			if dbID != databaseID {
				return ErrStorageHasOtherAttachedDatabases
			}
		}
		err = s.storageService.TransferStorageToWorkspace(
			user,
			*backupConfig.StorageID,
			request.TargetWorkspaceID,
			&databaseID,
		)
		if err != nil {
			return err
		}
	} else if request.TargetStorageID != nil {
		// Re-point the backup config at an existing storage that must already
		// live in the target workspace.
		targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID)
		if err != nil {
			return err
		}
		if targetStorage.WorkspaceID != request.TargetWorkspaceID {
			return ErrTargetStorageNotInTargetWorkspace
		}
		backupConfig.StorageID = request.TargetStorageID
		backupConfig.Storage = targetStorage
		_, err = s.backupConfigRepository.Save(backupConfig)
		if err != nil {
			return err
		}
	} else {
		// A transfer must resolve to SOME storage in the target workspace.
		return ErrTargetStorageNotSpecified
	}
	err = s.databaseService.TransferDatabaseToWorkspace(databaseID, request.TargetWorkspaceID)
	if err != nil {
		return err
	}
	if len(request.TargetNotifierIDs) > 0 {
		err = s.assignTargetNotifiers(databaseID, request.TargetNotifierIDs)
		if err != nil {
			return err
		}
	}
	return nil
}
// transferNotifiers moves every notifier attached to the database into the
// target workspace.
//
// NOTE(review): errors from TransferNotifierToWorkspace are deliberately
// discarded, so this is best-effort — a notifier that fails to transfer is
// silently left in the source workspace. Confirm this is intended.
func (s *BackupConfigService) transferNotifiers(
	user *users_models.User,
	database *databases.Database,
	targetWorkspaceID uuid.UUID,
) {
	for _, notifier := range database.Notifiers {
		_ = s.notifierService.TransferNotifierToWorkspace(
			user,
			notifier.ID,
			targetWorkspaceID,
			&database.ID,
		)
	}
}
// validateTargetNotifiers checks that every explicitly requested target
// notifier exists and belongs to the transfer's target workspace.
func (s *BackupConfigService) validateTargetNotifiers(request *TransferDatabaseRequest) error {
	for _, id := range request.TargetNotifierIDs {
		candidate, err := s.notifierService.GetNotifierByID(id)
		if err != nil {
			return err
		}
		if candidate.WorkspaceID != request.TargetWorkspaceID {
			return ErrTargetNotifierNotInTargetWorkspace
		}
	}
	return nil
}
// assignTargetNotifiers resolves the given notifier IDs and replaces the
// database's notifier set with the resolved notifiers.
func (s *BackupConfigService) assignTargetNotifiers(
	databaseID uuid.UUID,
	notifierIDs []uuid.UUID,
) error {
	resolved := make([]notifiers.Notifier, 0, len(notifierIDs))
	for _, id := range notifierIDs {
		found, err := s.notifierService.GetNotifierByID(id)
		if err != nil {
			return err
		}
		resolved = append(resolved, *found)
	}
	return s.databaseService.UpdateDatabaseNotifiers(databaseID, resolved)
}
func storageIDsEqual(id1, id2 *uuid.UUID) bool {
if id1 == nil && id2 == nil {
return true

View File

@@ -0,0 +1,229 @@
package backups_config
import (
"fmt"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/storages"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
)
// Verifies that a backup config can attach a storage that belongs to the
// same workspace as the database, and that the saved config references both.
func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
	router := createTestRouterWithStorage()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)

	backupTime := "04:00"
	cfg := BackupConfig{
		DatabaseID:       database.ID,
		IsBackupsEnabled: true,
		StorePeriod:      period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &backupTime,
		},
		Storage:             storage,
		SendNotificationsOn: []BackupNotificationType{NotificationBackupFailed},
		IsRetryIfFailed:     true,
		MaxFailedTriesCount: 3,
		Encryption:          BackupEncryptionNone,
	}

	var saved BackupConfig
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		cfg,
		http.StatusOK,
		&saved,
	)

	// The persisted config must point back at both database and storage.
	assert.Equal(t, database.ID, saved.DatabaseID)
	assert.NotNil(t, saved.StorageID)
	assert.Equal(t, storage.ID, *saved.StorageID)
}
// Verifies cross-workspace attachment is rejected: a backup config may only
// reference a storage from the database's own workspace.
func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
	router := createTestRouterWithStorage()

	// The database lives in workspace 1...
	owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
	database := createTestDatabaseViaAPI("Test Database", workspace1.ID, owner1.Token, router)

	// ...while the storage belongs to an unrelated workspace 2.
	owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
	foreignStorage := createTestStorage(workspace2.ID)

	backupTime := "04:00"
	cfg := BackupConfig{
		DatabaseID:       database.ID,
		IsBackupsEnabled: true,
		StorePeriod:      period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &backupTime,
		},
		Storage:             foreignStorage,
		SendNotificationsOn: []BackupNotificationType{NotificationBackupFailed},
		IsRetryIfFailed:     true,
		MaxFailedTriesCount: 3,
		Encryption:          BackupEncryptionNone,
	}

	testResp := test_utils.MakePostRequest(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner1.Token,
		cfg,
		http.StatusBadRequest,
	)
	assert.Contains(t, string(testResp.Body), "storage does not belong to the same workspace")
}
// Verifies that a storage referenced by a saved backup config cannot be
// deleted and that the API returns a descriptive 400.
func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
	router := createTestRouterWithStorage()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	timeOfDay := "04:00"
	// Attach the storage to the database by saving a backup config.
	request := BackupConfig{
		DatabaseID:       database.ID,
		IsBackupsEnabled: true,
		StorePeriod:      period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		Storage: storage,
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:     true,
		MaxFailedTriesCount: 3,
		Encryption:          BackupEncryptionNone,
	}
	var response BackupConfig
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		request,
		http.StatusOK,
		&response,
	)
	// With the config in place, deleting the storage must now fail.
	testResp := test_utils.MakeDeleteRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/storages/%s", storage.ID.String()),
		"Bearer "+owner.Token,
		http.StatusBadRequest,
	)
	assert.Contains(
		t,
		string(testResp.Body),
		"storage has attached databases and cannot be deleted",
	)
}
// Verifies that a storage referenced by a saved backup config cannot be
// transferred to another workspace.
func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
	router := createTestRouterWithStorage()
	owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
	targetWorkspace := workspaces_testing.CreateTestWorkspace("Target Workspace", owner, router)
	database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
	storage := createTestStorage(workspace.ID)
	timeOfDay := "04:00"
	// Attach the storage to the database by saving a backup config.
	request := BackupConfig{
		DatabaseID:       database.ID,
		IsBackupsEnabled: true,
		StorePeriod:      period.PeriodWeek,
		BackupInterval: &intervals.Interval{
			Interval:  intervals.IntervalDaily,
			TimeOfDay: &timeOfDay,
		},
		Storage: storage,
		SendNotificationsOn: []BackupNotificationType{
			NotificationBackupFailed,
		},
		IsRetryIfFailed:     true,
		MaxFailedTriesCount: 3,
		Encryption:          BackupEncryptionNone,
	}
	var response BackupConfig
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/backup-configs/save",
		"Bearer "+owner.Token,
		request,
		http.StatusOK,
		&response,
	)
	// Attempting to move the attached storage must be rejected with a 400.
	transferRequest := storages.TransferStorageRequest{
		TargetWorkspaceID: targetWorkspace.ID,
	}
	testResp := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/storages/%s/transfer", storage.ID.String()),
		"Bearer "+owner.Token,
		transferRequest,
		http.StatusBadRequest,
	)
	assert.Contains(
		t,
		string(testResp.Body),
		"storage has attached databases and cannot be transferred",
	)
}
// createTestRouterWithStorage builds a test router wired with the workspace,
// membership, database, backup-config and storage controllers, then runs each
// feature's dependency setup.
// NOTE(review): dependency setup happens after router creation — presumably
// the controllers resolve their services lazily; confirm before reordering.
func createTestRouterWithStorage() *gin.Engine {
	router := workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		GetBackupConfigController(),
		storages.GetStorageController(),
	)
	storages.SetupDependencies()
	databases.SetupDependencies()
	SetupDependencies()
	return router
}

View File

@@ -26,6 +26,7 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/databases/test-connection-direct", c.TestDatabaseConnectionDirect)
router.POST("/databases/:id/copy", c.CopyDatabase)
router.GET("/databases/notifier/:id/is-using", c.IsNotifierUsing)
router.GET("/databases/notifier/:id/databases-count", c.CountDatabasesByNotifier)
router.POST("/databases/is-readonly", c.IsUserReadOnly)
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
}
@@ -299,6 +300,39 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"isUsing": isUsing})
}
// CountDatabasesByNotifier
// @Summary Count databases using a notifier
// @Description Get the count of databases that are using a specific notifier
// @Tags databases
// @Produce json
// @Param id path string true "Notifier ID"
// @Success 200 {object} map[string]int
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /databases/notifier/{id}/databases-count [get]
func (c *DatabaseController) CountDatabasesByNotifier(ctx *gin.Context) {
	// Authentication gate: reject anonymous callers first.
	user, ok := users_middleware.GetUserFromContext(ctx)
	if !ok {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	// The path parameter must be a well-formed UUID.
	notifierID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
		return
	}

	count, svcErr := c.databaseService.CountDatabasesByNotifier(user, notifierID)
	if svcErr != nil {
		ctx.JSON(http.StatusInternalServerError, gin.H{"error": svcErr.Error()})
		return
	}
	ctx.JSON(http.StatusOK, gin.H{"count": count})
}
// CopyDatabase
// @Summary Copy a database
// @Description Copy an existing database configuration
@@ -358,13 +392,13 @@ func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
return
}
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
isReadOnly, privileges, err := c.databaseService.IsUserReadOnly(user, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly, Privileges: privileges})
}
// CreateReadOnlyUser

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"testing"
@@ -11,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine {
return router
}
// getTestPostgresConfig builds a PostgreSQL 16 connection config pointing at
// the local test container. Panics when TEST_POSTGRES_16_PORT cannot be
// parsed, since no test can run without it.
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
	env := config.GetEnv()
	pgPort, parseErr := strconv.Atoi(env.TestPostgres16Port)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", parseErr))
	}
	dbName := "testdb"
	return &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     "localhost",
		Port:     pgPort,
		Username: "testuser",
		Password: "testpassword",
		Database: &dbName,
		CpuCount: 1,
	}
}
// getTestMariadbConfig builds a MariaDB 10.11 connection config for the local
// test container, falling back to port 33111 when TEST_MARIADB_1011_PORT is
// unset. Panics on an unparseable port value.
func getTestMariadbConfig() *mariadb.MariadbDatabase {
	env := config.GetEnv()
	rawPort := env.TestMariadb1011Port
	if rawPort == "" {
		rawPort = "33111"
	}
	mariadbPort, parseErr := strconv.Atoi(rawPort)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", parseErr))
	}
	dbName := "testdb"
	return &mariadb.MariadbDatabase{
		Version:  tools.MariadbVersion1011,
		Host:     "localhost",
		Port:     mariadbPort,
		Username: "testuser",
		Password: "testpassword",
		Database: &dbName,
	}
}
// getTestMongodbConfig builds a MongoDB 7 connection config for the local
// test container, falling back to port 27070 when TEST_MONGODB_70_PORT is
// unset. Authenticates as root against the admin database.
func getTestMongodbConfig() *mongodb.MongodbDatabase {
	env := config.GetEnv()
	rawPort := env.TestMongodb70Port
	if rawPort == "" {
		rawPort = "27070"
	}
	mongoPort, parseErr := strconv.Atoi(rawPort)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", parseErr))
	}
	return &mongodb.MongodbDatabase{
		Version:      tools.MongodbVersion7,
		Host:         "localhost",
		Port:         mongoPort,
		Username:     "root",
		Password:     "rootpassword",
		Database:     "testdb",
		AuthDatabase: "admin",
		IsHttps:      false,
		CpuCount:     1,
	}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -88,20 +155,11 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = member.Token
}
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
var response Database
@@ -132,20 +190,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
testResp := test_utils.MakePostRequest(
@@ -737,7 +786,13 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -745,9 +800,9 @@ func createTestDatabaseViaAPI(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -780,21 +835,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
pgConfig := getTestPostgresConfig()
plainPassword := "testpassword"
pgConfig.Password = plainPassword
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
var createdDatabase Database
@@ -854,38 +902,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "PostgreSQL Database",
databaseType: DatabaseTypePostgres,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
pgConfig := getTestPostgresConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "original-password-secret",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
pgConfig := getTestPostgresConfig()
pgConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion17,
Host: "updated-host",
Port: 5433,
Username: "updated_user",
Password: "",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -895,7 +928,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)
@@ -905,36 +938,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MariaDB Database",
databaseType: DatabaseTypeMariadb,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
mariaConfig := getTestMariadbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "original-password-secret",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
mariaConfig := getTestMariadbConfig()
mariaConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion114,
Host: "updated-host",
Port: 3307,
Username: "updated_user",
Password: "",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -944,7 +964,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mariadb.Password)
@@ -954,40 +974,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MongoDB Database",
databaseType: DatabaseTypeMongodb,
createDatabase: func(workspaceID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: 27017,
Username: "root",
Password: "original-password-secret",
Database: "test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
mongoConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion8,
Host: "updated-host",
Port: 27018,
Username: "updated_user",
Password: "",
Database: "updated_test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -997,7 +1000,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "rootpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mongodb.Password)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
@@ -23,12 +24,13 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -94,6 +96,16 @@ func (m *MariadbDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -111,6 +123,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -131,15 +144,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MariadbDatabase) PopulateVersionIfEmpty(
func (m *MariadbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMariadbVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MariadbDatabase) PopulateVersion(
@@ -175,8 +221,8 @@ func (m *MariadbDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -185,17 +231,17 @@ func (m *MariadbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -205,33 +251,44 @@ func (m *MariadbDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "DELETE HISTORY", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MariadbDatabase) CreateReadOnlyUser(
@@ -261,7 +318,7 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
for attempt := range maxRetries {
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -326,10 +383,23 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
// HasPrivilege reports whether this database's detected privilege list
// (comma-separated, stored in m.Privileges) contains priv.
func (m *MariadbDatabase) HasPrivilege(priv string) bool {
	return HasPrivilege(m.Privileges, priv)
}
// HasPrivilege reports whether priv appears as an entry in the
// comma-separated privileges string. Entries are compared exactly
// (case-sensitive) after trimming surrounding whitespace.
func HasPrivilege(privileges, priv string) bool {
	remaining := privileges
	for {
		entry, rest, more := strings.Cut(remaining, ",")
		if strings.TrimSpace(entry) == priv {
			return true
		}
		if !more {
			return false
		}
		remaining = rest
	}
}
func (m *MariadbDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
tlsConfig = "skip-verify"
}
return fmt.Sprintf(
@@ -420,6 +490,99 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
//
// It parses SHOW GRANTS output for the connected user. A privilege counts only
// when its grant targets the given database (or `*`.*) or is global (*.*).
// PROCESS is only counted from global grants, and ALL PRIVILEGES on a matching
// scope implies every backup privilege plus PROCESS.
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
	rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
	if err != nil {
		return "", fmt.Errorf("failed to check grants: %w", err)
	}
	defer func() { _ = rows.Close() }()
	// Database-level privileges relevant to mariadb-dump backups.
	backupPrivileges := []string{
		"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
	}
	detectedPrivileges := make(map[string]bool)
	hasProcess := false
	hasAllPrivileges := false
	// Grants may show the database name with LIKE-escaped underscores
	// (e.g. `test\_db`), so match against the escaped form.
	escapedDB := strings.ReplaceAll(database, "_", "\\_")
	// Matches "ON `<db>`.*" (any quoting style) or "ON *.*"-like wildcard DB.
	dbPattern := regexp.MustCompile(
		fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
	)
	// Matches a truly global grant: "ON *.*".
	globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
	for rows.Next() {
		var grant string
		if err := rows.Scan(&grant); err != nil {
			return "", fmt.Errorf("failed to scan grant: %w", err)
		}
		// ALL PRIVILEGES on the target DB or globally implies everything below.
		if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
			if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
				hasAllPrivileges = true
			}
		}
		// Individual backup privileges scoped to the DB or globally.
		if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
			for _, priv := range backupPrivileges {
				if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
					detectedPrivileges[priv] = true
				}
			}
		}
		// PROCESS is only counted when granted globally (*.*).
		if globalPattern.MatchString(grant) &&
			regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
			hasProcess = true
		}
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("error iterating grants: %w", err)
	}
	if hasAllPrivileges {
		for _, priv := range backupPrivileges {
			detectedPrivileges[priv] = true
		}
		hasProcess = true
	}
	// Emit a deterministic, sorted, comma-separated list (stored on the model
	// and later consumed by HasPrivilege / checkBackupPermissions).
	privileges := make([]string, 0, len(detectedPrivileges)+1)
	for priv := range detectedPrivileges {
		privileges = append(privileges, priv)
	}
	if hasProcess {
		privileges = append(privileges, "PROCESS")
	}
	sort.Strings(privileges)
	return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
// privileges is the comma-separated list produced by detectPrivileges.
func checkBackupPermissions(privileges string) error {
	required := []string{"SELECT", "SHOW VIEW", "PROCESS"}

	missing := make([]string, 0, len(required))
	for _, priv := range required {
		if HasPrivilege(privileges, priv) {
			continue
		}
		missing = append(missing, priv)
	}

	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf(
		"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
		strings.Join(missing, ", "),
	)
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -18,6 +18,171 @@ import (
"databasus-backend/internal/util/tools"
)
// Verifies that TestConnection rejects a user with only SELECT on the target
// database: SHOW VIEW and PROCESS are missing, so the "insufficient
// permissions" error must be returned. Runs against every supported MariaDB
// version in parallel subtests.
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version tools.MariadbVersion
		port    string
	}{
		{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
		{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
		{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
		{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
		{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
		{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
		{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
		{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
		{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
		{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
		{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
	}
	// NOTE(review): tc is captured by parallel subtests; this relies on
	// Go 1.22+ per-iteration loop variables (the file elsewhere uses
	// range-over-int, so the toolchain is new enough).
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToMariadbContainer(t, tc.port, tc.version)
			defer container.DB.Close()
			// Seed a table so the schema is non-empty.
			_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`CREATE TABLE permission_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
			assert.NoError(t, err)
			// Create a user that only has SELECT on the test database.
			limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
			limitedPassword := "limitedpassword123"
			_, err = container.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				limitedUsername,
				limitedPassword,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
				container.Database,
				limitedUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer dropUserSafe(container.DB, limitedUsername)
			// Connect as the limited user; the privilege check must fail.
			mariadbModel := &MariadbDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: limitedUsername,
				Password: limitedPassword,
				Database: &container.Database,
				IsHttps:  false,
			}
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = mariadbModel.TestConnection(logger, nil, uuid.New())
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "insufficient permissions")
		})
	}
}
// Verifies that TestConnection accepts a user holding the full backup
// privilege set: SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT on the DB
// plus global PROCESS. Runs against every supported MariaDB version in
// parallel subtests.
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version tools.MariadbVersion
		port    string
	}{
		{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
		{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
		{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
		{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
		{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
		{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
		{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
		{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
		{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
		{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
		{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToMariadbContainer(t, tc.port, tc.version)
			defer container.DB.Close()
			// Seed a table so the schema is non-empty.
			_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`CREATE TABLE backup_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
			assert.NoError(t, err)
			// Create a user with the full set of backup privileges.
			backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
			backupPassword := "backuppassword123"
			_, err = container.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				backupUsername,
				backupPassword,
			))
			assert.NoError(t, err)
			// DB-scoped privileges...
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
				container.Database,
				backupUsername,
			))
			assert.NoError(t, err)
			// ...plus PROCESS, which is a global (*.*) grant.
			_, err = container.DB.Exec(fmt.Sprintf(
				"GRANT PROCESS ON *.* TO '%s'@'%%'",
				backupUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer dropUserSafe(container.DB, backupUsername)
			// Connect as the backup user; the privilege check must pass.
			mariadbModel := &MariadbDatabase{
				Version:  tc.version,
				Host:     container.Host,
				Port:     container.Port,
				Username: backupUsername,
				Password: backupPassword,
				Database: &container.Database,
				IsHttps:  false,
			}
			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = mariadbModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -49,13 +214,56 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
// Verifies that a user created by CreateReadOnlyUser is classified as
// read-only by IsUserReadOnly and reports no write privileges.
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
	env := config.GetEnv()
	container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
	defer container.DB.Close()
	// Seed a table so grants have something to apply to.
	_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
	assert.NoError(t, err)
	_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
		id INT AUTO_INCREMENT PRIMARY KEY,
		data VARCHAR(255) NOT NULL
	)`)
	assert.NoError(t, err)
	_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
	assert.NoError(t, err)
	mariadbModel := createMariadbModel(container)
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()
	// Provision a fresh read-only user via the admin connection.
	username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	// Reconnect as the new user and run the read-only classification.
	readOnlyModel := &MariadbDatabase{
		Version:  mariadbModel.Version,
		Host:     mariadbModel.Host,
		Port:     mariadbModel.Port,
		Username: username,
		Password: password,
		Database: mariadbModel.Database,
		IsHttps:  false,
	}
	isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	assert.True(t, isReadOnly, "Read-only user should be read-only")
	assert.Empty(t, privileges, "Read-only user should have no write privileges")
	dropUserSafe(container.DB, username)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -127,9 +335,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
@@ -382,6 +596,5 @@ func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
}
// dropUserSafe drops the given MariaDB user on a best-effort basis.
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so any error is ignored.
func dropUserSafe(db *sqlx.DB, username string) {
	query := "DROP USER '" + username + "'@'%'"
	_, _ = db.Exec(query)
}

View File

@@ -5,7 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"regexp"
"strings"
"time"
"databasus-backend/internal/util/encryption"
@@ -95,6 +97,10 @@ func (m *MongodbDatabase) TestConnection(
}
m.Version = detectedVersion
if err := checkBackupPermissions(ctx, client, m.Username, m.Database, m.AuthDatabase); err != nil {
return err
}
return nil
}
@@ -134,14 +140,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MongodbDatabase) PopulateVersionIfEmpty(
func (m *MongodbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
}
@@ -185,10 +188,10 @@ func (m *MongodbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
uri := m.buildConnectionURI(password)
@@ -196,7 +199,7 @@ func (m *MongodbDatabase) IsUserReadOnly(
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if disconnectErr := client.Disconnect(ctx); disconnectErr != nil {
@@ -218,44 +221,153 @@ func (m *MongodbDatabase) IsUserReadOnly(
}},
}).Decode(&result)
if err != nil {
return false, fmt.Errorf("failed to get user info: %w", err)
return false, nil, fmt.Errorf("failed to get user info: %w", err)
}
writeRoles := []string{
"readWrite", "readWriteAnyDatabase", "dbAdmin", "dbAdminAnyDatabase",
"userAdmin", "userAdminAnyDatabase", "clusterAdmin", "root",
"dbOwner", "backup", "restore",
writeRoles := map[string]bool{
"readWrite": true,
"readWriteAnyDatabase": true,
"dbAdmin": true,
"dbAdminAnyDatabase": true,
"userAdmin": true,
"userAdminAnyDatabase": true,
"clusterAdmin": true,
"clusterManager": true,
"hostManager": true,
"root": true,
"dbOwner": true,
"restore": true,
"__system": true,
}
// Roles that are read-only for our backup purposes
// The "backup" role has insert/update on mms.backup collection but is needed for mongodump
readOnlyRoles := map[string]bool{
"read": true,
"backup": true,
}
writeActions := map[string]bool{
"insert": true,
"update": true,
"remove": true,
"createCollection": true,
"dropCollection": true,
"createIndex": true,
"dropIndex": true,
"convertToCapped": true,
"dropDatabase": true,
"renameCollection": true,
"createUser": true,
"dropUser": true,
"updateUser": true,
"grantRole": true,
"revokeRole": true,
"dropRole": true,
"createRole": true,
"updateRole": true,
"enableSharding": true,
"shardCollection": true,
"addShard": true,
"removeShard": true,
"shutdown": true,
"replSetReconfig": true,
"replSetStateChange": true,
}
var detectedRoles []string
users, ok := result["users"].(bson.A)
if !ok || len(users) == 0 {
return true, nil
return true, detectedRoles, nil
}
user, ok := users[0].(bson.M)
if !ok {
return true, nil
return true, detectedRoles, nil
}
roles, ok := user["roles"].(bson.A)
if !ok {
return true, nil
return true, detectedRoles, nil
}
// Collect all role names and check for write roles
for _, roleDoc := range roles {
role, ok := roleDoc.(bson.M)
if !ok {
continue
}
roleName, _ := role["role"].(string)
for _, writeRole := range writeRoles {
if roleName == writeRole {
return false, nil
if roleName != "" {
detectedRoles = append(detectedRoles, roleName)
}
}
// Check if any detected role is a write role
for _, roleName := range detectedRoles {
if writeRoles[roleName] {
return false, detectedRoles, nil
}
}
// If all roles are known read-only roles (read, backup), skip inherited privilege check
allRolesReadOnly := true
for _, roleName := range detectedRoles {
if !readOnlyRoles[roleName] {
allRolesReadOnly = false
break
}
}
if allRolesReadOnly && len(detectedRoles) > 0 {
return true, detectedRoles, nil
}
// Check inherited privileges for custom roles
var privResult bson.M
err = adminDB.RunCommand(ctx, bson.D{
{Key: "usersInfo", Value: bson.D{
{Key: "user", Value: m.Username},
{Key: "db", Value: authDB},
}},
{Key: "showPrivileges", Value: true},
}).Decode(&privResult)
if err != nil {
return false, nil, fmt.Errorf("failed to get user privileges: %w", err)
}
privUsers, ok := privResult["users"].(bson.A)
if !ok || len(privUsers) == 0 {
return true, detectedRoles, nil
}
privUser, ok := privUsers[0].(bson.M)
if !ok {
return true, detectedRoles, nil
}
// Check inheritedPrivileges for write actions
inheritedPrivileges, ok := privUser["inheritedPrivileges"].(bson.A)
if ok {
for _, privDoc := range inheritedPrivileges {
priv, ok := privDoc.(bson.M)
if !ok {
continue
}
actions, ok := priv["actions"].(bson.A)
if !ok {
continue
}
for _, action := range actions {
actionStr, ok := action.(string)
if ok && writeActions[actionStr] {
return false, detectedRoles, nil
}
}
}
}
return true, nil
return true, detectedRoles, nil
}
func (m *MongodbDatabase) CreateReadOnlyUser(
@@ -290,7 +402,7 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
adminDB := client.Database(authDB)
err = adminDB.RunCommand(ctx, bson.D{
@@ -332,20 +444,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
m.Database,
authDB,
tlsOption,
tlsParams,
)
}
@@ -356,19 +468,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
authDB,
tlsOption,
tlsParams,
)
}
@@ -413,6 +525,128 @@ func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.Mong
}
}
// checkBackupPermissions verifies the user has sufficient privileges for mongodump backup.
// Required: 'read' role on target database OR 'backup' role on admin OR 'readAnyDatabase' role.
//
// It runs usersInfo (with showPrivileges) against the auth database and accepts
// the user when either (a) a directly-assigned role is a known backup-capable
// role, or (b) the inherited privileges include the 'find' action on the target
// database. Returns nil on success, otherwise an error listing the user's
// current roles and the requirements.
func checkBackupPermissions(
	ctx context.Context,
	client *mongo.Client,
	username, database, authDatabase string,
) error {
	// Fall back to the conventional "admin" auth database when none is configured.
	authDB := authDatabase
	if authDB == "" {
		authDB = "admin"
	}
	adminDB := client.Database(authDB)

	var result bson.M
	err := adminDB.RunCommand(ctx, bson.D{
		{Key: "usersInfo", Value: bson.D{
			{Key: "user", Value: username},
			{Key: "db", Value: authDB},
		}},
		// showPrivileges makes the server include "inheritedPrivileges" in the
		// reply, needed for the fallback check further down.
		{Key: "showPrivileges", Value: true},
	}).Decode(&result)
	if err != nil {
		return fmt.Errorf("failed to get user info: %w", err)
	}

	users, ok := result["users"].(bson.A)
	if !ok || len(users) == 0 {
		return errors.New("insufficient permissions for backup. User not found")
	}
	user, ok := users[0].(bson.M)
	if !ok {
		return errors.New("insufficient permissions for backup. Could not parse user info")
	}

	// Check roles for backup permissions
	roles, ok := user["roles"].(bson.A)
	if !ok {
		return errors.New("insufficient permissions for backup. No roles assigned")
	}

	// Built-in roles that grant enough read access for mongodump on any database.
	backupRoles := map[string]bool{
		"backup":               true,
		"root":                 true,
		"readAnyDatabase":      true,
		"dbOwner":              true,
		"__system":             true,
		"clusterAdmin":         true,
		"readWriteAnyDatabase": true,
	}

	// userRoles is collected purely for the error message at the end.
	var userRoles []string
	hasBackupRole := false
	hasReadOnTargetDB := false
	for _, roleDoc := range roles {
		role, ok := roleDoc.(bson.M)
		if !ok {
			continue
		}
		roleName, _ := role["role"].(string)
		roleDB, _ := role["db"].(string)
		if roleName != "" {
			userRoles = append(userRoles, roleName)
		}
		if backupRoles[roleName] {
			hasBackupRole = true
		}
		// NOTE(review): an empty roleDB is treated as matching the target
		// database — presumably a defensive default; confirm the server can
		// actually return roles without a "db" field.
		if roleName == "read" && (roleDB == database || roleDB == "") {
			hasReadOnTargetDB = true
		}
		if roleName == "readWrite" && (roleDB == database || roleDB == "") {
			hasReadOnTargetDB = true
		}
	}

	if hasBackupRole || hasReadOnTargetDB {
		return nil
	}

	// Check inherited privileges for 'find' action on target database.
	// This covers custom roles whose names match nothing above but which still
	// grant read access.
	inheritedPrivileges, ok := user["inheritedPrivileges"].(bson.A)
	if ok {
		for _, privDoc := range inheritedPrivileges {
			priv, ok := privDoc.(bson.M)
			if !ok {
				continue
			}
			resource, ok := priv["resource"].(bson.M)
			if !ok {
				continue
			}
			resourceDB, _ := resource["db"].(string)
			resourceCluster, _ := resource["cluster"].(bool)
			// An empty resource db or a cluster-scoped resource is accepted as
			// covering the target database.
			isTargetDB := resourceDB == database || resourceDB == "" || resourceCluster
			actions, ok := priv["actions"].(bson.A)
			if !ok {
				continue
			}
			for _, action := range actions {
				actionStr, ok := action.(string)
				if ok && actionStr == "find" && isTargetDB {
					return nil
				}
			}
		}
	}

	return fmt.Errorf(
		"insufficient permissions for backup. Current roles: %s. Required: 'read' role on database '%s' OR 'backup' role on admin OR 'readAnyDatabase' role",
		strings.Join(userRoles, ", "),
		database,
	)
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"net/url"
"os"
"strconv"
"strings"
@@ -19,6 +20,138 @@ import (
"databasus-backend/internal/util/tools"
)
// Test_TestConnection_InsufficientPermissions_ReturnsError checks that a
// MongoDB user with no roles fails the connection-time permission check.
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
	env := config.GetEnv()

	cases := []struct {
		name    string
		version tools.MongodbVersion
		port    string
	}{
		{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
		{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
		{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
		{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
		{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
		{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
		{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			mongoContainer := connectToMongodbContainer(t, tc.port, tc.version)
			defer mongoContainer.Client.Disconnect(context.Background())

			ctx := context.Background()
			targetDB := mongoContainer.Client.Database(mongoContainer.Database)

			// Seed one document so the target database is non-empty.
			_ = targetDB.Collection("permission_test").Drop(ctx)
			_, err := targetDB.Collection("permission_test").InsertOne(ctx, bson.M{"data": "test1"})
			assert.NoError(t, err)

			// Create a user with an empty role list — it cannot read anything.
			noRoleUser := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
			noRolePassword := "limitedpassword123"
			authDB := mongoContainer.Client.Database(mongoContainer.AuthDatabase)
			err = authDB.RunCommand(ctx, bson.D{
				{Key: "createUser", Value: noRoleUser},
				{Key: "pwd", Value: noRolePassword},
				{Key: "roles", Value: bson.A{}},
			}).Err()
			assert.NoError(t, err)
			defer dropUserSafe(mongoContainer.Client, noRoleUser, mongoContainer.AuthDatabase)

			model := &MongodbDatabase{
				Version:      tc.version,
				Host:         mongoContainer.Host,
				Port:         mongoContainer.Port,
				Username:     noRoleUser,
				Password:     noRolePassword,
				Database:     mongoContainer.Database,
				AuthDatabase: mongoContainer.AuthDatabase,
				IsHttps:      false,
				CpuCount:     1,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = model.TestConnection(logger, nil, uuid.New())
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "insufficient permissions")
		})
	}
}
// Test_TestConnection_SufficientPermissions_Success checks that a MongoDB user
// holding the 'read' role on the target database passes the connection-time
// permission check.
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
	env := config.GetEnv()

	cases := []struct {
		name    string
		version tools.MongodbVersion
		port    string
	}{
		{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
		{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
		{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
		{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
		{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
		{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
		{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			mongoContainer := connectToMongodbContainer(t, tc.port, tc.version)
			defer mongoContainer.Client.Disconnect(context.Background())

			ctx := context.Background()
			targetDB := mongoContainer.Client.Database(mongoContainer.Database)

			// Seed one document so there is data the user could back up.
			_ = targetDB.Collection("backup_test").Drop(ctx)
			_, err := targetDB.Collection("backup_test").InsertOne(ctx, bson.M{"data": "test1"})
			assert.NoError(t, err)

			// Create a user holding only 'read' on the target database.
			readUser := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
			readPassword := "backuppassword123"
			authDB := mongoContainer.Client.Database(mongoContainer.AuthDatabase)
			err = authDB.RunCommand(ctx, bson.D{
				{Key: "createUser", Value: readUser},
				{Key: "pwd", Value: readPassword},
				{Key: "roles", Value: bson.A{
					bson.D{
						{Key: "role", Value: "read"},
						{Key: "db", Value: mongoContainer.Database},
					},
				}},
			}).Err()
			assert.NoError(t, err)
			defer dropUserSafe(mongoContainer.Client, readUser, mongoContainer.AuthDatabase)

			model := &MongodbDatabase{
				Version:      tc.version,
				Host:         mongoContainer.Host,
				Port:         mongoContainer.Port,
				Username:     readUser,
				Password:     readPassword,
				Database:     mongoContainer.Database,
				AuthDatabase: mongoContainer.AuthDatabase,
				IsHttps:      false,
				CpuCount:     1,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = model.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -46,13 +179,52 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, roles, "Root user should have roles")
})
}
}
// Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue verifies that a user created by
// CreateReadOnlyUser is classified as read-only and reports its roles.
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
	env := config.GetEnv()

	container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
	defer container.Client.Disconnect(context.Background())

	ctx := context.Background()
	db := container.Client.Database(container.Database)

	// Seed one document so the database is non-empty when checked.
	_ = db.Collection("readonly_check_test").Drop(ctx)
	_, err := db.Collection("readonly_check_test").InsertOne(ctx, bson.M{"data": "test1"})
	assert.NoError(t, err)

	mongodbModel := createMongodbModel(container)
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))

	username, password, err := mongodbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	// Defer the cleanup immediately after creation so the user is dropped even
	// if a later assertion fails or panics (the original dropped it only on the
	// happy path at the end of the test).
	defer dropUserSafe(container.Client, username, container.AuthDatabase)

	readOnlyModel := &MongodbDatabase{
		Version:      mongodbModel.Version,
		Host:         mongodbModel.Host,
		Port:         mongodbModel.Port,
		Username:     username,
		Password:     password,
		Database:     mongodbModel.Database,
		AuthDatabase: mongodbModel.AuthDatabase,
		IsHttps:      false,
		CpuCount:     1,
	}

	isReadOnly, roles, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	assert.True(t, isReadOnly, "Read-only user should be read-only")
	assert.NotEmpty(t, roles, "Read-only user should have roles (read, backup)")
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -271,6 +443,7 @@ func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
}
@@ -281,7 +454,8 @@ func connectWithCredentials(
) *mongo.Client {
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, container.Host, container.Port,
url.QueryEscape(username), url.QueryEscape(password),
container.Host, container.Port,
container.Database, container.AuthDatabase,
)

View File

@@ -7,6 +7,8 @@ import (
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
"databasus-backend/internal/util/encryption"
@@ -22,12 +24,13 @@ type MysqlDatabase struct {
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MysqlDatabase) TableName() string {
@@ -93,6 +96,16 @@ func (m *MysqlDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -110,6 +123,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -130,15 +144,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields(
return nil
}
func (m *MysqlDatabase) PopulateVersionIfEmpty(
func (m *MysqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMysqlVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MysqlDatabase) PopulateVersion(
@@ -174,8 +221,8 @@ func (m *MysqlDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -184,17 +231,17 @@ func (m *MysqlDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -204,33 +251,45 @@ func (m *MysqlDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MysqlDatabase) CreateReadOnlyUser(
@@ -259,7 +318,7 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -325,10 +384,23 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
// HasPrivilege reports whether priv is present in this database's stored
// comma-separated privilege list (m.Privileges). Thin wrapper over the
// package-level HasPrivilege.
func (m *MysqlDatabase) HasPrivilege(priv string) bool {
	return HasPrivilege(m.Privileges, priv)
}
// HasPrivilege reports whether priv appears as one of the comma-separated
// entries in privileges. Each entry is compared after trimming surrounding
// whitespace; the match is exact and case-sensitive.
func HasPrivilege(privileges, priv string) bool {
	entries := strings.Split(privileges, ",")
	for _, entry := range entries {
		if strings.TrimSpace(entry) == priv {
			return true
		}
	}
	return false
}
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
tlsConfig = "skip-verify"
}
return fmt.Sprintf(
@@ -388,6 +460,99 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
// (sorted alphabetically, e.g. "EVENT,LOCK TABLES,PROCESS,SELECT").
//
// A privilege counts when granted globally (ON *.*) or on the target database
// (ON `db`.*). PROCESS counts only when granted globally — the code never
// accepts a database-scoped PROCESS grant. A matching "ALL PRIVILEGES" grant
// implies every backup privilege plus PROCESS.
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
	rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
	if err != nil {
		return "", fmt.Errorf("failed to check grants: %w", err)
	}
	defer func() { _ = rows.Close() }()

	// Privileges mysqldump relies on (global PROCESS handled separately below).
	backupPrivileges := []string{
		"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
	}

	// SHOW GRANTS escapes "_" in database names (e.g. `my\_db`), so escape the
	// name the same way before quoting it into the pattern.
	escapedDB := strings.ReplaceAll(database, "_", "\\_")
	dbPattern := regexp.MustCompile(
		fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
	)
	globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
	allPrivilegesPattern := regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`)
	processPattern := regexp.MustCompile(`(?i)\bPROCESS\b`)

	// Precompile one matcher per backup privilege. The original called
	// regexp.MustCompile inside the rows loop, recompiling every pattern for
	// every grant row; hoisting is a free O(rows*privs) -> O(privs) win.
	privPatterns := make(map[string]*regexp.Regexp, len(backupPrivileges))
	for _, priv := range backupPrivileges {
		privPatterns[priv] = regexp.MustCompile(`(?i)\b` + priv + `\b`)
	}

	detectedPrivileges := make(map[string]bool)
	hasProcess := false
	hasAllPrivileges := false

	for rows.Next() {
		var grant string
		if err := rows.Scan(&grant); err != nil {
			return "", fmt.Errorf("failed to scan grant: %w", err)
		}

		// A grant is relevant when scoped globally or to the target database.
		inScope := globalPattern.MatchString(grant) || dbPattern.MatchString(grant)

		if inScope && allPrivilegesPattern.MatchString(grant) {
			hasAllPrivileges = true
		}

		if inScope {
			for _, priv := range backupPrivileges {
				if privPatterns[priv].MatchString(grant) {
					detectedPrivileges[priv] = true
				}
			}
		}

		// PROCESS is only honored as a global grant.
		if globalPattern.MatchString(grant) && processPattern.MatchString(grant) {
			hasProcess = true
		}
	}
	if err := rows.Err(); err != nil {
		return "", fmt.Errorf("error iterating grants: %w", err)
	}

	// ALL PRIVILEGES implies every backup privilege and PROCESS.
	if hasAllPrivileges {
		for _, priv := range backupPrivileges {
			detectedPrivileges[priv] = true
		}
		hasProcess = true
	}

	privileges := make([]string, 0, len(detectedPrivileges)+1)
	for priv := range detectedPrivileges {
		privileges = append(privileges, priv)
	}
	if hasProcess {
		privileges = append(privileges, "PROCESS")
	}
	sort.Strings(privileges)

	return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
func checkBackupPermissions(privileges string) error {
	required := []string{"SELECT", "SHOW VIEW", "PROCESS"}

	var missing []string
	for _, name := range required {
		if HasPrivilege(privileges, name) {
			continue
		}
		missing = append(missing, name)
	}

	if len(missing) == 0 {
		return nil
	}
	return fmt.Errorf(
		"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
		strings.Join(missing, ", "),
	)
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -18,6 +18,165 @@ import (
"databasus-backend/internal/util/tools"
)
// Test_TestConnection_InsufficientPermissions_ReturnsError checks that a MySQL
// user holding only SELECT (missing SHOW VIEW and PROCESS) fails the
// connection-time permission check.
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
	env := config.GetEnv()

	cases := []struct {
		name    string
		version tools.MysqlVersion
		port    string
	}{
		{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
		{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
		{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
		{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			mysqlContainer := connectToMysqlContainer(t, tc.port, tc.version)
			defer mysqlContainer.DB.Close()

			// Seed a table so there is data the user could back up.
			_, err := mysqlContainer.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(`CREATE TABLE permission_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			// Create a user with SELECT only — missing SHOW VIEW and PROCESS.
			selectOnlyUser := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
			selectOnlyPassword := "limitedpassword123"
			_, err = mysqlContainer.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				selectOnlyUser,
				selectOnlyPassword,
			))
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(fmt.Sprintf(
				"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
				mysqlContainer.Database,
				selectOnlyUser,
			))
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer func() {
				_, _ = mysqlContainer.DB.Exec(
					fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", selectOnlyUser),
				)
			}()

			model := &MysqlDatabase{
				Version:  tc.version,
				Host:     mysqlContainer.Host,
				Port:     mysqlContainer.Port,
				Username: selectOnlyUser,
				Password: selectOnlyPassword,
				Database: &mysqlContainer.Database,
				IsHttps:  false,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = model.TestConnection(logger, nil, uuid.New())
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "insufficient permissions")
		})
	}
}
// Test_TestConnection_SufficientPermissions_Success checks that a MySQL user
// granted the full backup privilege set (SELECT, SHOW VIEW, LOCK TABLES,
// TRIGGER, EVENT on the database plus global PROCESS) passes the
// connection-time permission check.
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
	env := config.GetEnv()

	cases := []struct {
		name    string
		version tools.MysqlVersion
		port    string
	}{
		{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
		{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
		{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
		{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			mysqlContainer := connectToMysqlContainer(t, tc.port, tc.version)
			defer mysqlContainer.DB.Close()

			// Seed a table so there is data the user could back up.
			_, err := mysqlContainer.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(`CREATE TABLE backup_test (
				id INT AUTO_INCREMENT PRIMARY KEY,
				data VARCHAR(255) NOT NULL
			)`)
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
			assert.NoError(t, err)

			backupUser := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
			backupUserPassword := "backuppassword123"
			_, err = mysqlContainer.DB.Exec(fmt.Sprintf(
				"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
				backupUser,
				backupUserPassword,
			))
			assert.NoError(t, err)

			// Database-scoped privileges plus the global-only PROCESS grant.
			_, err = mysqlContainer.DB.Exec(fmt.Sprintf(
				"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
				mysqlContainer.Database,
				backupUser,
			))
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec(fmt.Sprintf(
				"GRANT PROCESS ON *.* TO '%s'@'%%'",
				backupUser,
			))
			assert.NoError(t, err)
			_, err = mysqlContainer.DB.Exec("FLUSH PRIVILEGES")
			assert.NoError(t, err)
			defer func() {
				_, _ = mysqlContainer.DB.Exec(
					fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", backupUser),
				)
			}()

			model := &MysqlDatabase{
				Version:  tc.version,
				Host:     mysqlContainer.Host,
				Port:     mysqlContainer.Port,
				Username: backupUser,
				Password: backupUserPassword,
				Database: &mysqlContainer.Database,
				IsHttps:  false,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = model.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -42,13 +201,57 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
// Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue verifies that a user created by
// CreateReadOnlyUser is reported as read-only with no write privileges.
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
	env := config.GetEnv()

	container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
	defer container.DB.Close()

	// Seed a table so the database is non-empty when the read-only user is checked.
	_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
	assert.NoError(t, err)
	_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
		id INT AUTO_INCREMENT PRIMARY KEY,
		data VARCHAR(255) NOT NULL
	)`)
	assert.NoError(t, err)
	_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
	assert.NoError(t, err)

	mysqlModel := createMysqlModel(container)
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()

	username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	// Defer the cleanup immediately after creation so the user is dropped even
	// if a later assertion fails or panics (the original dropped it only on the
	// happy path at the end of the test). Drop errors are ignored best-effort.
	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
	}()

	readOnlyModel := &MysqlDatabase{
		Version:  mysqlModel.Version,
		Host:     mysqlModel.Host,
		Port:     mysqlModel.Port,
		Username: username,
		Password: password,
		Database: mysqlModel.Database,
		IsHttps:  false,
	}

	isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	assert.True(t, isReadOnly, "Read-only user should be read-only")
	assert.Empty(t, privileges, "Read-only user should have no write privileges")
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -109,9 +312,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",

View File

@@ -137,16 +137,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
return nil
}
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
// PopulateDbData detects and sets the PostgreSQL version.
// This should be called before encrypting sensitive fields.
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
func (p *PostgresqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.Version != "" {
return nil
}
return p.PopulateVersion(logger, encryptor, databaseID)
}
@@ -192,29 +189,33 @@ func (p *PostgresqlDatabase) PopulateVersion(
// IsUserReadOnly checks if the database user has read-only privileges.
//
// This method performs a comprehensive security check by examining:
// - Role-level attributes (superuser, createrole, createdb)
// - Role-level attributes (superuser, createrole, createdb, bypassrls, replication)
// - Database-level privileges (CREATE, TEMP)
// - Schema-level privileges (CREATE on any non-system schema)
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
// - Function-level privileges (EXECUTE on SECURITY DEFINER functions)
//
// A user is considered read-only only if they have ZERO write privileges
// across all three levels. This ensures the database user follows the
// across all levels. This ensures the database user follows the
// principle of least privilege for backup operations.
//
// Returns: (isReadOnly, detectedPrivileges, error)
func (p *PostgresqlDatabase) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
@@ -222,22 +223,38 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
}
}()
var privileges []string
// LEVEL 1: Check role-level attributes
var isSuperuser, canCreateRole, canCreateDB bool
var isSuperuser, canCreateRole, canCreateDB, canBypassRLS, canReplication bool
err = conn.QueryRow(ctx, `
SELECT
rolsuper,
rolcreaterole,
rolcreatedb
rolcreatedb,
rolbypassrls,
rolreplication
FROM pg_roles
WHERE rolname = current_user
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB, &canBypassRLS, &canReplication)
if err != nil {
return false, fmt.Errorf("failed to check role attributes: %w", err)
return false, nil, fmt.Errorf("failed to check role attributes: %w", err)
}
if isSuperuser || canCreateRole || canCreateDB {
return false, nil
if isSuperuser {
privileges = append(privileges, "SUPERUSER")
}
if canCreateRole {
privileges = append(privileges, "CREATEROLE")
}
if canCreateDB {
privileges = append(privileges, "CREATEDB")
}
if canBypassRLS {
privileges = append(privileges, "BYPASSRLS")
}
if canReplication {
privileges = append(privileges, "REPLICATION")
}
// LEVEL 2: Check database-level privileges
@@ -248,46 +265,34 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
`).Scan(&canCreate, &canTemp)
if err != nil {
return false, fmt.Errorf("failed to check database privileges: %w", err)
return false, nil, fmt.Errorf("failed to check database privileges: %w", err)
}
if canCreate || canTemp {
return false, nil
if canCreate {
privileges = append(privileges, "CREATE (database)")
}
if canTemp {
privileges = append(privileges, "TEMP")
}
// LEVEL 2.5: Check schema-level CREATE privileges
schemaRows, err := conn.Query(ctx, `
SELECT DISTINCT nspname
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
`)
var hasSchemaCreate bool
err = conn.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
)
`).Scan(&hasSchemaCreate)
if err != nil {
return false, fmt.Errorf("failed to check schema privileges: %w", err)
return false, nil, fmt.Errorf("failed to check schema privileges: %w", err)
}
defer schemaRows.Close()
// If user has CREATE privilege on any schema, they're not read-only
if schemaRows.Next() {
return false, nil
}
if err := schemaRows.Err(); err != nil {
return false, fmt.Errorf("error iterating schema privileges: %w", err)
if hasSchemaCreate {
privileges = append(privileges, "CREATE (schema)")
}
// LEVEL 3: Check table-level write permissions
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, fmt.Errorf("failed to check table privileges: %w", err)
}
defer rows.Close()
writePrivileges := map[string]bool{
"INSERT": true,
"UPDATE": true,
@@ -297,22 +302,56 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
"TRIGGER": true,
}
var tablePrivileges []string
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, nil, fmt.Errorf("failed to check table privileges: %w", err)
}
for rows.Next() {
var privilege string
if err := rows.Scan(&privilege); err != nil {
return false, fmt.Errorf("failed to scan privilege: %w", err)
}
if writePrivileges[privilege] {
return false, nil
rows.Close()
return false, nil, fmt.Errorf("failed to scan privilege: %w", err)
}
tablePrivileges = append(tablePrivileges, privilege)
}
rows.Close()
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating privileges: %w", err)
return false, nil, fmt.Errorf("error iterating privileges: %w", err)
}
return true, nil
for _, privilege := range tablePrivileges {
if writePrivileges[privilege] {
privileges = append(privileges, privilege)
}
}
// LEVEL 4: Check for EXECUTE privilege on functions that are SECURITY DEFINER
var funcCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_proc p
JOIN pg_namespace n ON p.pronamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
AND p.prosecdef = true
AND has_function_privilege(current_user, p.oid, 'EXECUTE')
`).Scan(&funcCount)
if err != nil {
return false, nil, fmt.Errorf("failed to check function privileges: %w", err)
}
if funcCount > 0 {
privileges = append(privileges, "EXECUTE (SECURITY DEFINER)")
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
@@ -383,7 +422,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := conn.Begin(ctx)
if err != nil {
@@ -631,13 +670,9 @@ func testSingleDatabaseConnection(
}
postgresDb.Version = detectedVersion
// Test if we can perform basic operations (like pg_dump would need)
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
return fmt.Errorf(
"basic operations test failed for database '%s': %w",
*postgresDb.Database,
err,
)
// Verify user has sufficient permissions for backup operations
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
return err
}
return nil
@@ -670,18 +705,73 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
}
}
// testBasicOperations tests basic operations that backup tools need
func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) error {
var hasCreatePriv bool
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
var missingPrivileges []string
// Check CONNECT privilege on database
var hasConnect bool
err := conn.QueryRow(ctx, "SELECT has_database_privilege(current_user, current_database(), 'CONNECT')").
Scan(&hasCreatePriv)
Scan(&hasConnect)
if err != nil {
return fmt.Errorf("cannot check database privileges: %w", err)
}
if !hasConnect {
missingPrivileges = append(missingPrivileges, "CONNECT on database")
}
if !hasCreatePriv {
return fmt.Errorf("user does not have CONNECT privilege on database '%s'", dbName)
// Check USAGE privilege on at least one non-system schema
var schemaCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
`).Scan(&schemaCount)
if err != nil {
return fmt.Errorf("cannot check schema privileges: %w", err)
}
if schemaCount == 0 {
missingPrivileges = append(missingPrivileges, "USAGE on at least one schema")
}
// Check SELECT privilege on at least one table (if tables exist)
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
var tableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
`).Scan(&tableCount)
if err != nil {
return fmt.Errorf("cannot check table count: %w", err)
}
if tableCount > 0 {
// Check if user has SELECT on at least one of these tables
var selectableTableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: CONNECT on database, USAGE on schemas, SELECT on tables",
strings.Join(missingPrivileges, ", "),
)
}
return nil
@@ -695,16 +785,22 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password s
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
"host=%s port=%d user=%s password='%s' dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
p.Host,
p.Port,
p.Username,
password,
escapeConnectionStringValue(password),
dbName,
sslMode,
)
}
// escapeConnectionStringValue escapes a value for use inside a single-quoted
// libpq connection-string parameter: both backslashes and single quotes must
// be prefixed with a backslash.
func escapeConnectionStringValue(value string) string {
	var escaped strings.Builder
	escaped.Grow(len(value))
	for _, ch := range value {
		if ch == '\\' || ch == '\'' {
			escaped.WriteByte('\\')
		}
		escaped.WriteRune(ch)
	}
	return escaped.String()
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -13,11 +13,236 @@ import (
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"databasus-backend/internal/config"
"databasus-backend/internal/util/tools"
)
// Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully verifies
// that TestConnection works when the user's password contains spaces, i.e.
// the connection-string builder quotes/escapes the password value correctly.
func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.T) {
	env := config.GetEnv()
	container := connectToPostgresContainer(t, env.TestPostgres16Port)
	defer container.DB.Close()

	passwordWithSpaces := "test password with spaces"
	usernameWithSpaces := fmt.Sprintf("testuser_spaces_%s", uuid.New().String()[:8])

	// Seed a table so the backup-permission check has something to SELECT.
	_, err := container.DB.Exec(`
		DROP TABLE IF EXISTS password_test CASCADE;
		CREATE TABLE password_test (
			id SERIAL PRIMARY KEY,
			data TEXT NOT NULL
		);
		INSERT INTO password_test (data) VALUES ('test1');
	`)
	assert.NoError(t, err)

	// Create a user whose password contains spaces and grant the minimum
	// privileges needed for backup: CONNECT, USAGE, SELECT.
	_, err = container.DB.Exec(fmt.Sprintf(
		`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
		usernameWithSpaces,
		passwordWithSpaces,
	))
	assert.NoError(t, err)
	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
		container.Database,
		usernameWithSpaces,
	))
	assert.NoError(t, err)
	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT USAGE ON SCHEMA public TO "%s"`,
		usernameWithSpaces,
	))
	assert.NoError(t, err)
	_, err = container.DB.Exec(fmt.Sprintf(
		`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
		usernameWithSpaces,
	))
	assert.NoError(t, err)
	// Best-effort cleanup of the temporary user.
	defer func() {
		_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, usernameWithSpaces))
	}()

	pgModel := &PostgresqlDatabase{
		Version:  tools.GetPostgresqlVersionEnum("16"),
		Host:     container.Host,
		Port:     container.Port,
		Username: usernameWithSpaces,
		Password: passwordWithSpaces,
		Database: &container.Database,
		IsHttps:  false,
		CpuCount: 1,
	}

	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	// The connection must succeed despite the whitespace in the password.
	err = pgModel.TestConnection(logger, nil, uuid.New())
	assert.NoError(t, err)
}
// Test_TestConnection_InsufficientPermissions_ReturnsError verifies that
// TestConnection reports "insufficient permissions" for a user that can
// CONNECT but has no SELECT grants, across all supported PostgreSQL versions.
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version string
		port    string
	}{
		{"PostgreSQL 12", "12", env.TestPostgres12Port},
		{"PostgreSQL 13", "13", env.TestPostgres13Port},
		{"PostgreSQL 14", "14", env.TestPostgres14Port},
		{"PostgreSQL 15", "15", env.TestPostgres15Port},
		{"PostgreSQL 16", "16", env.TestPostgres16Port},
		{"PostgreSQL 17", "17", env.TestPostgres17Port},
	}
	for _, tc := range cases {
		// Capture the range variable: with t.Parallel() the closure runs after
		// the loop advances, so before Go 1.22 every subtest would otherwise
		// see the last element of cases. A no-op on Go 1.22+.
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToPostgresContainer(t, tc.port)
			defer container.DB.Close()

			// Seed a table so the SELECT-permission branch of the check runs.
			_, err := container.DB.Exec(`
				DROP TABLE IF EXISTS permission_test CASCADE;
				CREATE TABLE permission_test (
					id SERIAL PRIMARY KEY,
					data TEXT NOT NULL
				);
				INSERT INTO permission_test (data) VALUES ('test1');
			`)
			assert.NoError(t, err)

			limitedUsername := fmt.Sprintf("limited_user_%s", uuid.New().String()[:8])
			limitedPassword := "limitedpassword123"

			// Deliberately grant only LOGIN + CONNECT — no SELECT on any table.
			_, err = container.DB.Exec(fmt.Sprintf(
				`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
				limitedUsername,
				limitedPassword,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec(fmt.Sprintf(
				`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
				container.Database,
				limitedUsername,
			))
			assert.NoError(t, err)
			defer func() {
				_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedUsername))
			}()

			pgModel := &PostgresqlDatabase{
				Version:  tools.GetPostgresqlVersionEnum(tc.version),
				Host:     container.Host,
				Port:     container.Port,
				Username: limitedUsername,
				Password: limitedPassword,
				Database: &container.Database,
				IsHttps:  false,
				CpuCount: 1,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = pgModel.TestConnection(logger, nil, uuid.New())
			assert.Error(t, err)
			if err != nil {
				// Guarded so a nil err (failed assertion above) cannot panic here.
				assert.Contains(t, err.Error(), "insufficient permissions")
			}
		})
	}
}
// Test_TestConnection_SufficientPermissions_Success verifies that
// TestConnection succeeds for a user granted exactly the privileges required
// for backup (CONNECT, USAGE, SELECT), across all supported PostgreSQL
// versions.
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
	env := config.GetEnv()
	cases := []struct {
		name    string
		version string
		port    string
	}{
		{"PostgreSQL 12", "12", env.TestPostgres12Port},
		{"PostgreSQL 13", "13", env.TestPostgres13Port},
		{"PostgreSQL 14", "14", env.TestPostgres14Port},
		{"PostgreSQL 15", "15", env.TestPostgres15Port},
		{"PostgreSQL 16", "16", env.TestPostgres16Port},
		{"PostgreSQL 17", "17", env.TestPostgres17Port},
	}
	for _, tc := range cases {
		// Capture the range variable: with t.Parallel() the closure runs after
		// the loop advances, so before Go 1.22 every subtest would otherwise
		// see the last element of cases. A no-op on Go 1.22+.
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			container := connectToPostgresContainer(t, tc.port)
			defer container.DB.Close()

			// Seed a table so the SELECT-permission check has data to look at.
			_, err := container.DB.Exec(`
				DROP TABLE IF EXISTS backup_test CASCADE;
				CREATE TABLE backup_test (
					id SERIAL PRIMARY KEY,
					data TEXT NOT NULL
				);
				INSERT INTO backup_test (data) VALUES ('test1');
			`)
			assert.NoError(t, err)

			backupUsername := fmt.Sprintf("backup_user_%s", uuid.New().String()[:8])
			backupPassword := "backuppassword123"

			// Grant exactly the minimum privileges for pg_dump-style backups.
			_, err = container.DB.Exec(fmt.Sprintf(
				`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
				backupUsername,
				backupPassword,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec(fmt.Sprintf(
				`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
				container.Database,
				backupUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec(fmt.Sprintf(
				`GRANT USAGE ON SCHEMA public TO "%s"`,
				backupUsername,
			))
			assert.NoError(t, err)
			_, err = container.DB.Exec(fmt.Sprintf(
				`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
				backupUsername,
			))
			assert.NoError(t, err)
			defer func() {
				_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, backupUsername))
			}()

			pgModel := &PostgresqlDatabase{
				Version:  tools.GetPostgresqlVersionEnum(tc.version),
				Host:     container.Host,
				Port:     container.Port,
				Username: backupUsername,
				Password: backupPassword,
				Database: &container.Database,
				IsHttps:  false,
				CpuCount: 1,
			}

			logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
			err = pgModel.TestConnection(logger, nil, uuid.New())
			assert.NoError(t, err)
		})
	}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -44,13 +269,60 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Admin user should not be read-only")
assert.NotEmpty(t, privileges, "Admin user should have privileges")
})
}
}
// Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue verifies that a PostgreSQL user
// created by CreateReadOnlyUser is reported as read-only with no write
// privileges.
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
	env := config.GetEnv()
	container := connectToPostgresContainer(t, env.TestPostgres16Port)
	defer container.DB.Close()

	// Seed a table so the privilege scan has user objects to inspect.
	_, err := container.DB.Exec(`
		DROP TABLE IF EXISTS readonly_check_test CASCADE;
		CREATE TABLE readonly_check_test (
			id SERIAL PRIMARY KEY,
			data TEXT NOT NULL
		);
		INSERT INTO readonly_check_test (data) VALUES ('test1');
	`)
	assert.NoError(t, err)

	pgModel := createPostgresModel(container)
	logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
	ctx := context.Background()

	username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	// Clean up in a defer so the user is removed even when a later assertion
	// fails. DROP OWNED first to release default-privilege dependencies.
	defer func() {
		if _, dropErr := container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username)); dropErr != nil {
			t.Logf("Warning: Failed to drop owned objects: %v", dropErr)
		}
		if _, dropErr := container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username)); dropErr != nil {
			t.Logf("Warning: Failed to drop user: %v", dropErr)
		}
	}()

	// Connect as the newly created user and run the read-only check.
	readOnlyModel := &PostgresqlDatabase{
		Version:  pgModel.Version,
		Host:     pgModel.Host,
		Port:     pgModel.Port,
		Username: username,
		Password: password,
		Database: pgModel.Database,
		IsHttps:  false,
		CpuCount: 1,
	}

	isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
	assert.NoError(t, err)
	assert.True(t, isReadOnly, "Read-only user should be read-only")
	assert.Empty(t, privileges, "Read-only user should have no write privileges")
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -105,9 +377,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
@@ -142,7 +420,6 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -186,7 +463,6 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -202,8 +478,10 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
defer container.DB.Close()
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS schema_a;
CREATE SCHEMA IF NOT EXISTS schema_b;
DROP SCHEMA IF EXISTS schema_a CASCADE;
DROP SCHEMA IF EXISTS schema_b CASCADE;
CREATE SCHEMA schema_a;
CREATE SCHEMA schema_b;
CREATE TABLE schema_a.table_a (id INT, data TEXT);
CREATE TABLE schema_b.table_b (id INT, data TEXT);
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
@@ -234,7 +512,6 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "data_b", dataB)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -341,7 +618,7 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
)
adminDB, err := sqlx.Connect("postgres", dsn)
assert.NoError(t, err)
require.NoError(t, err)
defer adminDB.Close()
tableName := fmt.Sprintf(
@@ -483,6 +760,7 @@ func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
Password: container.Password,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
}

View File

@@ -39,4 +39,5 @@ func GetDatabaseController() *DatabaseController {
// SetupDependencies registers cross-feature hooks once the singletons exist:
// the database service subscribes to workspace deletions and is installed as
// the notifier service's database counter.
func SetupDependencies() {
	workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
	notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
}

View File

@@ -6,5 +6,6 @@ type CreateReadOnlyUserResponse struct {
}
type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
IsReadOnly bool `json:"isReadOnly"`
Privileges []string `json:"privileges"`
}

View File

@@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
return nil
}
func (d *Database) PopulateVersionIfEmpty(
func (d *Database) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
if d.Postgresql != nil {
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Postgresql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mysql != nil {
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mysql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mariadb != nil {
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mariadb.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mongodb != nil {
return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mongodb.PopulateDbData(logger, encryptor, d.ID)
}
return nil
}

View File

@@ -243,3 +243,19 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
return databases, nil
}
// GetDatabasesIDsByNotifierID returns the IDs of every database linked to the
// given notifier through the database_notifiers join table.
func (r *DatabaseRepository) GetDatabasesIDsByNotifierID(
	notifierID uuid.UUID,
) ([]uuid.UUID, error) {
	var attachedIDs []uuid.UUID
	queryErr := storage.GetDb().
		Table("database_notifiers").
		Where("notifier_id = ?", notifierID).
		Pluck("database_id", &attachedIDs).
		Error
	if queryErr != nil {
		return nil, queryErr
	}
	return attachedIDs, nil
}

View File

@@ -52,6 +52,17 @@ func (s *DatabaseService) AddDbCopyListener(
s.dbCopyListener = append(s.dbCopyListener, dbCopyListener)
}
// GetNotifierAttachedDatabasesIDs returns the IDs of all databases that have
// the given notifier attached.
//
// NOTE(review): no authorization check is done here — presumably callers have
// already validated access to the notifier; confirm before exposing widely.
func (s *DatabaseService) GetNotifierAttachedDatabasesIDs(
	notifierID uuid.UUID,
) ([]uuid.UUID, error) {
	// Direct delegation: the repository already returns (ids, err) in exactly
	// the shape we need, so the previous error re-wrap was redundant.
	return s.dbRepository.GetDatabasesIDsByNotifierID(notifierID)
}
func (s *DatabaseService) CreateDatabase(
user *users_models.User,
workspaceID uuid.UUID,
@@ -71,8 +82,8 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -126,14 +137,20 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
for _, notifier := range database.Notifiers {
if notifier.WorkspaceID != *existingDatabase.WorkspaceID {
return errors.New("notifier does not belong to this workspace")
}
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return err
}
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -251,6 +268,23 @@ func (s *DatabaseService) IsNotifierUsing(
return s.dbRepository.IsNotifierUsing(notifierID)
}
// CountDatabasesByNotifier returns how many databases currently have the
// given notifier attached. The notifier is fetched first on behalf of the
// user — presumably this doubles as the access check; confirm in
// notifierService.GetNotifier.
func (s *DatabaseService) CountDatabasesByNotifier(
	user *users_models.User,
	notifierID uuid.UUID,
) (int, error) {
	if _, err := s.notifierService.GetNotifier(user, notifierID); err != nil {
		return 0, err
	}

	attachedIDs, err := s.dbRepository.GetDatabasesIDsByNotifierID(notifierID)
	if err != nil {
		return 0, err
	}

	return len(attachedIDs), nil
}
func (s *DatabaseService) TestDatabaseConnection(
user *users_models.User,
databaseID uuid.UUID,
@@ -481,6 +515,48 @@ func (s *DatabaseService) CopyDatabase(
return copiedDatabase, nil
}
// TransferDatabaseToWorkspace moves a database into the target workspace and
// records the move in the audit log.
//
// NOTE(review): no authorization check is performed here — assumed to be
// enforced by the caller; confirm.
func (s *DatabaseService) TransferDatabaseToWorkspace(
	databaseID uuid.UUID,
	targetWorkspaceID uuid.UUID,
) error {
	database, err := s.dbRepository.FindByID(databaseID)
	if err != nil {
		return err
	}

	// Capture the source workspace for the audit message before overwriting.
	// Guard against a nil pointer so the log never formats a nil *uuid.UUID.
	sourceWorkspace := "(none)"
	if database.WorkspaceID != nil {
		sourceWorkspace = database.WorkspaceID.String()
	}

	database.WorkspaceID = &targetWorkspaceID
	if _, err = s.dbRepository.Save(database); err != nil {
		return err
	}

	s.auditLogService.WriteAuditLog(
		fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
			database.Name, sourceWorkspace, targetWorkspaceID),
		nil,
		&targetWorkspaceID,
	)

	return nil
}
// UpdateDatabaseNotifiers replaces the notifier list of the given database
// and persists the change.
func (s *DatabaseService) UpdateDatabaseNotifiers(
	databaseID uuid.UUID,
	newNotifiers []notifiers.Notifier,
) error {
	existing, findErr := s.dbRepository.FindByID(databaseID)
	if findErr != nil {
		return findErr
	}

	existing.Notifiers = newNotifiers
	_, saveErr := s.dbRepository.Save(existing)
	return saveErr
}
func (s *DatabaseService) SetHealthStatus(
databaseID uuid.UUID,
healthStatus *HealthStatus,
@@ -518,17 +594,17 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
func (s *DatabaseService) IsUserReadOnly(
user *users_models.User,
database *Database,
) (bool, error) {
) (bool, []string, error) {
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return false, err
return false, nil, err
}
if existingDatabase.WorkspaceID == nil {
return false, errors.New("cannot check user for database without workspace")
return false, nil, errors.New("cannot check user for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
@@ -536,20 +612,20 @@ func (s *DatabaseService) IsUserReadOnly(
user,
)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this database")
return false, nil, errors.New("insufficient permissions to access this database")
}
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
return false, errors.New("database does not belong to this workspace")
return false, nil, errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return false, err
return false, nil, err
}
usingDatabase = existingDatabase
@@ -557,10 +633,10 @@ func (s *DatabaseService) IsUserReadOnly(
if database.WorkspaceID != nil {
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this workspace")
return false, nil, errors.New("insufficient permissions to access this workspace")
}
}
@@ -600,7 +676,7 @@ func (s *DatabaseService) IsUserReadOnly(
usingDatabase.ID,
)
default:
return false, errors.New("read-only check not supported for this database type")
return false, nil, errors.New("read-only check not supported for this database type")
}
}

View File

@@ -1,6 +1,12 @@
package databases
import (
"fmt"
"strconv"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -9,6 +15,71 @@ import (
"github.com/google/uuid"
)
// GetTestPostgresConfig builds a PostgreSQL 16 connection model pointing at
// the local test container. Panics when the configured port is not numeric,
// since test fixtures cannot proceed without it.
func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
	portNumber, parseErr := strconv.Atoi(config.GetEnv().TestPostgres16Port)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", parseErr))
	}

	databaseName := "testdb"
	return &postgresql.PostgresqlDatabase{
		Version:  tools.PostgresqlVersion16,
		Host:     "localhost",
		Port:     portNumber,
		Username: "testuser",
		Password: "testpassword",
		Database: &databaseName,
		CpuCount: 1,
	}
}
// GetTestMariadbConfig builds a MariaDB 10.11 connection model pointing at the
// local test container, falling back to port 33111 when the env var is unset
// (presumably the docker-compose default — confirm). Panics on a non-numeric
// port, since test fixtures cannot proceed without it.
func GetTestMariadbConfig() *mariadb.MariadbDatabase {
	rawPort := config.GetEnv().TestMariadb1011Port
	if rawPort == "" {
		rawPort = "33111"
	}
	portNumber, parseErr := strconv.Atoi(rawPort)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", parseErr))
	}

	databaseName := "testdb"
	return &mariadb.MariadbDatabase{
		Version:  tools.MariadbVersion1011,
		Host:     "localhost",
		Port:     portNumber,
		Username: "testuser",
		Password: "testpassword",
		Database: &databaseName,
	}
}
// GetTestMongodbConfig builds a MongoDB 7 connection model pointing at the
// local test container, falling back to port 27070 when the env var is unset
// (presumably the docker-compose default — confirm). Panics on a non-numeric
// port, since test fixtures cannot proceed without it.
func GetTestMongodbConfig() *mongodb.MongodbDatabase {
	rawPort := config.GetEnv().TestMongodb70Port
	if rawPort == "" {
		rawPort = "27070"
	}
	portNumber, parseErr := strconv.Atoi(rawPort)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", parseErr))
	}

	return &mongodb.MongodbDatabase{
		Version:      tools.MongodbVersion7,
		Host:         "localhost",
		Port:         portNumber,
		Username:     "root",
		Password:     "rootpassword",
		Database:     "testdb",
		AuthDatabase: "admin",
		IsHttps:      false,
		CpuCount:     1,
	}
}
func CreateTestDatabase(
workspaceID uuid.UUID,
storage *storages.Storage,
@@ -18,16 +89,7 @@ func CreateTestDatabase(
WorkspaceID: &workspaceID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
CpuCount: 1,
},
Postgresql: GetTestPostgresConfig(),
Notifiers: []notifiers.Notifier{
*notifier,
},

View File

@@ -1,7 +1,9 @@
package disk
import (
"databasus-backend/internal/config"
"fmt"
"path/filepath"
"runtime"
"github.com/shirou/gopsutil/v4/disk"
@@ -12,10 +14,14 @@ type DiskService struct{}
func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
platform := s.detectPlatform()
// Set path based on platform
path := "/"
var path string
if platform == PlatformWindows {
path = "C:\\"
} else {
// Use databasus-data folder location for Linux (Docker)
cfg := config.GetEnv()
path = filepath.Dir(cfg.DataFolder) // Gets /databasus-data from /databasus-data/backups
}
diskUsage, err := disk.Usage(path)

View File

@@ -12,13 +12,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -205,20 +203,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -10,13 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -293,20 +291,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -1,6 +1,8 @@
package notifiers
import (
"errors"
users_middleware "databasus-backend/internal/features/users/middleware"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"net/http"
@@ -20,6 +22,7 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/notifiers/:id", c.GetNotifier)
router.DELETE("/notifiers/:id", c.DeleteNotifier)
router.POST("/notifiers/:id/test", c.SendTestNotification)
router.POST("/notifiers/:id/transfer", c.TransferNotifierToWorkspace)
router.POST("/notifiers/direct-test", c.SendTestNotificationDirect)
}
@@ -55,7 +58,7 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
}
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToManageNotifier) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -93,7 +96,7 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
notifier, err := c.notifierService.GetNotifier(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view notifier in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToViewNotifier) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -137,7 +140,7 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToViewNotifiers) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -174,7 +177,7 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
}
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToManageNotifier) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -211,7 +214,7 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
}
if err := c.notifierService.SendTestNotification(user, id); err != nil {
if err.Error() == "insufficient permissions to test notifier in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToTestNotifier) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -222,6 +225,57 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "test notification sent successfully"})
}
// TransferNotifierToWorkspace
// @Summary Transfer notifier to another workspace
// @Description Transfer a notifier from one workspace to another
// @Tags notifiers
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param id path string true "Notifier ID"
// @Param request body TransferNotifierRequest true "Target workspace ID"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id}/transfer [post]
func (c *NotifierController) TransferNotifierToWorkspace(ctx *gin.Context) {
	// Resolve the authenticated user from the request context.
	authUser, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	// The notifier to transfer is addressed by its UUID path parameter.
	notifierID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
		return
	}

	// Bind and validate the request body; an explicit nil-UUID check keeps
	// the error message actionable.
	var payload TransferNotifierRequest
	if bindErr := ctx.ShouldBindJSON(&payload); bindErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}
	if payload.TargetWorkspaceID == uuid.Nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "targetWorkspaceId is required"})
		return
	}

	// Delegate to the service layer; permission failures map to 403,
	// everything else to 400.
	transferErr := c.notifierService.TransferNotifierToWorkspace(
		authUser,
		notifierID,
		payload.TargetWorkspaceID,
		nil,
	)
	if transferErr != nil {
		if errors.Is(transferErr, ErrInsufficientPermissionsInSourceWorkspace) ||
			errors.Is(transferErr, ErrInsufficientPermissionsInTargetWorkspace) {
			ctx.JSON(http.StatusForbidden, gin.H{"error": transferErr.Error()})
			return
		}
		ctx.JSON(http.StatusBadRequest, gin.H{"error": transferErr.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"message": "notifier transferred successfully"})
}
// SendTestNotificationDirect
// @Summary Send test notification directly
// @Description Send a test notification using a notifier object provided in the request

View File

@@ -202,164 +202,161 @@ func Test_SendTestNotificationExisting_NotificationSent(t *testing.T) {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_ViewerCanViewNotifiers_ButCannotModify(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
viewer,
users_enums.WorkspaceRoleViewer,
owner.Token,
router,
)
func Test_WorkspaceRolePermissions_Notifiers(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
canCreate bool
canUpdate bool
canDelete bool
}{
{
name: "owner can manage notifiers",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "admin can manage notifiers",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "member can manage notifiers",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "viewer can view but cannot modify notifiers",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
canCreate: false,
canUpdate: false,
canDelete: false,
},
{
name: "global admin can manage notifiers",
workspaceRole: nil,
isGlobalAdmin: true,
canCreate: true,
canUpdate: true,
canDelete: true,
},
}
notifier := createNewNotifier(workspace.ID)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createRouter()
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Viewer can GET notifiers
var notifiers []Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
"Bearer "+viewer.Token,
http.StatusOK,
&notifiers,
)
assert.Len(t, notifiers, 1)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(
workspace,
testUser,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = testUser.Token
}
// Viewer cannot CREATE notifier
newNotifier := createNewNotifier(workspace.ID)
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, *newNotifier, http.StatusForbidden,
)
// Owner creates initial notifier for all test cases
var ownerNotifier Notifier
notifier := createNewNotifier(workspace.ID)
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/notifiers", "Bearer "+owner.Token,
*notifier, http.StatusOK, &ownerNotifier,
)
// Viewer cannot UPDATE notifier
savedNotifier.Name = "Updated by viewer"
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, savedNotifier, http.StatusForbidden,
)
// Test GET notifiers
var notifiers []Notifier
test_utils.MakeGetRequestAndUnmarshal(
t, router,
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
"Bearer "+testUserToken, http.StatusOK, &notifiers,
)
assert.Len(t, notifiers, 1)
// Viewer cannot DELETE notifier
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+viewer.Token,
http.StatusForbidden,
)
// Test CREATE notifier
createStatusCode := http.StatusOK
if !tt.canCreate {
createStatusCode = http.StatusForbidden
}
newNotifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
if tt.canCreate {
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
*newNotifier, createStatusCode, &savedNotifier,
)
assert.NotEmpty(t, savedNotifier.ID)
} else {
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
*newNotifier, createStatusCode,
)
}
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
// Test UPDATE notifier
updateStatusCode := http.StatusOK
if !tt.canUpdate {
updateStatusCode = http.StatusForbidden
}
ownerNotifier.Name = "Updated by test user"
if tt.canUpdate {
var updatedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
ownerNotifier, updateStatusCode, &updatedNotifier,
)
assert.Equal(t, "Updated by test user", updatedNotifier.Name)
} else {
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+testUserToken,
ownerNotifier, updateStatusCode,
)
}
func Test_MemberCanManageNotifiers(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
// Test DELETE notifier
deleteStatusCode := http.StatusOK
if !tt.canDelete {
deleteStatusCode = http.StatusForbidden
}
test_utils.MakeDeleteRequest(
t, router,
fmt.Sprintf("/api/v1/notifiers/%s", ownerNotifier.ID.String()),
"Bearer "+testUserToken, deleteStatusCode,
)
notifier := createNewNotifier(workspace.ID)
// Member can CREATE notifier
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+member.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
assert.NotEmpty(t, savedNotifier.ID)
// Member can UPDATE notifier
savedNotifier.Name = "Updated by member"
var updatedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+member.Token,
savedNotifier,
http.StatusOK,
&updatedNotifier,
)
assert.Equal(t, "Updated by member", updatedNotifier.Name)
// Member can DELETE notifier
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_AdminCanManageNotifiers(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
admin,
users_enums.WorkspaceRoleAdmin,
owner.Token,
router,
)
notifier := createNewNotifier(workspace.ID)
// Admin can CREATE, UPDATE, DELETE
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+admin.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
savedNotifier.Name = "Updated by admin"
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+admin.Token, savedNotifier, http.StatusOK,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Cleanup
if tt.canCreate {
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
}
if !tt.canDelete {
deleteNotifier(t, router, ownerNotifier.ID, workspace.ID, owner.Token)
}
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func Test_UserNotInWorkspace_CannotAccessNotifiers(t *testing.T) {
@@ -965,6 +962,192 @@ func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
}
}
// Test_TransferNotifier_PermissionsEnforced is a table-driven test for the
// POST /notifiers/{id}/transfer endpoint. It verifies that owner, admin, and
// member roles (and global admins) may transfer a notifier between
// workspaces, while viewers are rejected with 403. For successful transfers
// it also confirms the notifier ends up in the target workspace.
func Test_TransferNotifier_PermissionsEnforced(t *testing.T) {
	tests := []struct {
		name string
		// Role of the acting user in the source workspace; nil for the
		// global-admin case (no workspace membership needed).
		sourceRole *users_enums.WorkspaceRole
		// Role of the acting user in the target workspace; nil for the
		// global-admin case.
		targetRole         *users_enums.WorkspaceRole
		isGlobalAdmin      bool
		expectSuccess      bool
		expectedStatusCode int
	}{
		{
			name:               "owner in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "admin in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "member in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "viewer in both workspaces cannot transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      false,
			expectedStatusCode: http.StatusForbidden,
		},
		{
			name:               "global admin can transfer",
			sourceRole:         nil,
			targetRole:         nil,
			isGlobalAdmin:      true,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			router := createRouter()
			// Stub out the databases dependency so the transfer is never
			// blocked by attached databases.
			GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
			// Two distinct owners so the acting user's access to each
			// workspace comes only from the roles granted below.
			sourceOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			targetOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			sourceWorkspace := workspaces_testing.CreateTestWorkspace(
				"Source Workspace",
				sourceOwner,
				router,
			)
			targetWorkspace := workspaces_testing.CreateTestWorkspace(
				"Target Workspace",
				targetOwner,
				router,
			)
			// The notifier under test is created in the source workspace by
			// its owner.
			notifier := createNewNotifier(sourceWorkspace.ID)
			var savedNotifier Notifier
			test_utils.MakePostRequestAndUnmarshal(
				t,
				router,
				"/api/v1/notifiers",
				"Bearer "+sourceOwner.Token,
				*notifier,
				http.StatusOK,
				&savedNotifier,
			)
			// Build the acting user for this case: a fresh global admin, or a
			// member added to both workspaces with the table's roles.
			var testUserToken string
			if tt.isGlobalAdmin {
				admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
				testUserToken = admin.Token
			} else if tt.sourceRole != nil {
				testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
				workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
				workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
				testUserToken = testUser.Token
			}
			request := TransferNotifierRequest{
				TargetWorkspaceID: targetWorkspace.ID,
			}
			testResp := test_utils.MakePostRequest(
				t,
				router,
				fmt.Sprintf("/api/v1/notifiers/%s/transfer", savedNotifier.ID.String()),
				"Bearer "+testUserToken,
				request,
				tt.expectedStatusCode,
			)
			if tt.expectSuccess {
				assert.Contains(t, string(testResp.Body), "transferred successfully")
				// Verify the notifier is now visible from the target
				// workspace and carries the new workspace ID.
				var retrievedNotifier Notifier
				test_utils.MakeGetRequestAndUnmarshal(
					t,
					router,
					fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
					"Bearer "+targetOwner.Token,
					http.StatusOK,
					&retrievedNotifier,
				)
				assert.Equal(t, targetWorkspace.ID, retrievedNotifier.WorkspaceID)
				deleteNotifier(t, router, savedNotifier.ID, targetWorkspace.ID, targetOwner.Token)
			} else {
				assert.Contains(t, string(testResp.Body), "insufficient permissions")
				// Failed transfer: the notifier is still in the source
				// workspace, so clean it up there.
				deleteNotifier(t, router, savedNotifier.ID, sourceWorkspace.ID, sourceOwner.Token)
			}
			workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
			workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
		})
	}
}
// Test_TransferNotifierNotManagableWorkspace_TransferFailed verifies that a
// user who can manage the SOURCE workspace but has no membership in the
// TARGET workspace cannot transfer a notifier there: the API responds 403
// with the target-workspace permission error.
func Test_TransferNotifierNotManagableWorkspace_TransferFailed(t *testing.T) {
	router := createRouter()
	// Stub out the databases dependency so attachment checks never block the
	// transfer attempt.
	GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
	// userA owns workspace1; userB owns workspace2; userA is NOT a member of
	// workspace2.
	userA := users_testing.CreateTestUser(users_enums.UserRoleMember)
	userB := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", userA, router)
	workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", userB, router)
	notifier := createNewNotifier(workspace1.ID)
	var savedNotifier Notifier
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/notifiers",
		"Bearer "+userA.Token,
		*notifier,
		http.StatusOK,
		&savedNotifier,
	)
	// userA attempts to move the notifier into userB's workspace.
	request := TransferNotifierRequest{
		TargetWorkspaceID: workspace2.ID,
	}
	testResp := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/notifiers/%s/transfer", savedNotifier.ID.String()),
		"Bearer "+userA.Token,
		request,
		http.StatusForbidden,
	)
	// The error message must identify the TARGET workspace as the problem.
	assert.Contains(
		t,
		string(testResp.Body),
		"insufficient permissions to manage notifier in target workspace",
	)
	// The notifier never left workspace1, so clean it up there.
	deleteNotifier(t, router, savedNotifier.ID, workspace1.ID, userA.Token)
	workspaces_testing.RemoveTestWorkspace(workspace1, router)
	workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
// mockNotifierDatabaseCounter is a test double for the NotifierDatabaseCounter
// dependency: it reports that no databases are attached to any notifier, so
// delete and transfer operations are never blocked during these tests.
type mockNotifierDatabaseCounter struct{}

// GetNotifierAttachedDatabasesIDs always returns an empty (non-nil) slice and
// no error, regardless of the notifier ID.
func (m *mockNotifierDatabaseCounter) GetNotifierAttachedDatabasesIDs(
	notifierID uuid.UUID,
) ([]uuid.UUID, error) {
	return []uuid.UUID{}, nil
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -979,6 +1162,7 @@ func createRouter() *gin.Engine {
}
audit_logs.SetupDependencies()
GetNotifierService().SetNotifierDatabaseCounter(&mockNotifierDatabaseCounter{})
return router
}

View File

@@ -14,6 +14,7 @@ var notifierService = &NotifierService{
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
nil,
}
var notifierController = &NotifierController{
notifierService,

View File

@@ -0,0 +1,7 @@
package notifiers
import "github.com/google/uuid"
// TransferNotifierRequest is the JSON body of POST /notifiers/{id}/transfer.
// TargetWorkspaceID names the workspace the notifier should be moved into;
// gin's `binding:"required"` rejects a missing field, and the controller
// additionally rejects a nil UUID.
type TransferNotifierRequest struct {
	TargetWorkspaceID uuid.UUID `json:"targetWorkspaceId" binding:"required"`
}

View File

@@ -0,0 +1,36 @@
package notifiers
import "errors"
// Sentinel errors for the notifiers feature. Controllers match them with
// errors.Is to translate service failures into HTTP status codes, and tests
// assert on the exact message text — do not reword the strings casually.
var (
	// Permission errors (mapped to 403 by the controller).
	ErrInsufficientPermissionsToManageNotifier = errors.New(
		"insufficient permissions to manage notifier in this workspace",
	)
	ErrInsufficientPermissionsToViewNotifier = errors.New(
		"insufficient permissions to view notifier in this workspace",
	)
	ErrInsufficientPermissionsToViewNotifiers = errors.New(
		"insufficient permissions to view notifiers in this workspace",
	)
	ErrInsufficientPermissionsToTestNotifier = errors.New(
		"insufficient permissions to test notifier in this workspace",
	)
	// Workspace-consistency error: the notifier's workspace does not match
	// the one named in the request.
	ErrNotifierDoesNotBelongToWorkspace = errors.New(
		"notifier does not belong to this workspace",
	)
	// Transfer permission errors: manage rights are required in BOTH the
	// source and target workspaces.
	ErrInsufficientPermissionsInSourceWorkspace = errors.New(
		"insufficient permissions to manage notifier in source workspace",
	)
	ErrInsufficientPermissionsInTargetWorkspace = errors.New(
		"insufficient permissions to manage notifier in target workspace",
	)
	// Attachment errors: a notifier still referenced by databases cannot be
	// deleted or transferred (except alongside its single attached database).
	ErrNotifierHasAttachedDatabases = errors.New(
		"notifier has attached databases and cannot be deleted",
	)
	ErrNotifierHasAttachedDatabasesCannotTransfer = errors.New(
		"notifier has attached databases and cannot be transferred",
	)
	ErrNotifierHasOtherAttachedDatabasesCannotTransfer = errors.New(
		"notifier has other attached databases and cannot be transferred",
	)
)

View File

@@ -3,6 +3,8 @@ package notifiers
import (
"databasus-backend/internal/util/encryption"
"log/slog"
"github.com/google/uuid"
)
type NotificationSender interface {
@@ -19,3 +21,7 @@ type NotificationSender interface {
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}
// NotifierDatabaseCounter reports which databases reference a notifier. It is
// defined as an interface and injected into NotifierService (see
// SetNotifierDatabaseCounter) — presumably to avoid a direct package
// dependency on the databases feature; confirm against the wiring code.
type NotifierDatabaseCounter interface {
	// GetNotifierAttachedDatabasesIDs returns the IDs of all databases
	// attached to the given notifier.
	GetNotifierAttachedDatabasesIDs(notifierID uuid.UUID) ([]uuid.UUID, error)
}

View File

@@ -1,7 +1,6 @@
package notifiers
import (
"errors"
"fmt"
"log/slog"
@@ -14,11 +13,18 @@ import (
)
type NotifierService struct {
notifierRepository *NotifierRepository
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
notifierRepository *NotifierRepository
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
notifierDatabaseCounter NotifierDatabaseCounter
}
// SetNotifierDatabaseCounter injects the NotifierDatabaseCounter dependency.
// The package-level service is initially constructed with a nil counter, so
// callers must invoke this setter during wiring before delete/transfer
// operations run — those methods dereference the counter without a nil check.
func (s *NotifierService) SetNotifierDatabaseCounter(
	notifierDatabaseCounter NotifierDatabaseCounter,
) {
	s.notifierDatabaseCounter = notifierDatabaseCounter
}
func (s *NotifierService) SaveNotifier(
@@ -31,7 +37,7 @@ func (s *NotifierService) SaveNotifier(
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
return ErrInsufficientPermissionsToManageNotifier
}
isUpdate := notifier.ID != uuid.Nil
@@ -43,7 +49,7 @@ func (s *NotifierService) SaveNotifier(
}
if existingNotifier.WorkspaceID != workspaceID {
return errors.New("notifier does not belong to this workspace")
return ErrNotifierDoesNotBelongToWorkspace
}
existingNotifier.Update(notifier)
@@ -106,7 +112,17 @@ func (s *NotifierService) DeleteNotifier(
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
return ErrInsufficientPermissionsToManageNotifier
}
attachedDatabasesIDs, err := s.notifierDatabaseCounter.GetNotifierAttachedDatabasesIDs(
notifier.ID,
)
if err != nil {
return err
}
if len(attachedDatabasesIDs) > 0 {
return ErrNotifierHasAttachedDatabases
}
err = s.notifierRepository.Delete(notifier)
@@ -137,13 +153,17 @@ func (s *NotifierService) GetNotifier(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifier in this workspace")
return nil, ErrInsufficientPermissionsToViewNotifier
}
notifier.HideSensitiveData()
return notifier, nil
}
func (s *NotifierService) GetNotifierByID(id uuid.UUID) (*Notifier, error) {
return s.notifierRepository.FindByID(id)
}
func (s *NotifierService) GetNotifiers(
user *users_models.User,
workspaceID uuid.UUID,
@@ -153,7 +173,7 @@ func (s *NotifierService) GetNotifiers(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifiers in this workspace")
return nil, ErrInsufficientPermissionsToViewNotifiers
}
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
@@ -182,7 +202,7 @@ func (s *NotifierService) SendTestNotification(
return err
}
if !canView {
return errors.New("insufficient permissions to test notifier in this workspace")
return ErrInsufficientPermissionsToTestNotifier
}
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
@@ -210,7 +230,7 @@ func (s *NotifierService) SendTestNotificationToNotifier(
}
if existingNotifier.WorkspaceID != notifier.WorkspaceID {
return errors.New("notifier does not belong to this workspace")
return ErrNotifierDoesNotBelongToWorkspace
}
existingNotifier.Update(notifier)
@@ -269,6 +289,70 @@ func (s *NotifierService) SendNotification(
}
}
// TransferNotifierToWorkspace moves a notifier from its current workspace
// into targetWorkspaceID.
//
// The acting user must have DB-management rights in BOTH the source and the
// target workspace; otherwise ErrInsufficientPermissionsInSourceWorkspace or
// ErrInsufficientPermissionsInTargetWorkspace is returned. A notifier that is
// attached to databases cannot be transferred on its own
// (ErrNotifierHasAttachedDatabasesCannotTransfer). When the transfer happens
// as part of moving a single database, transferringWithDbID identifies that
// database and the transfer is allowed only if it is the sole attachment;
// any other attachment yields ErrNotifierHasOtherAttachedDatabasesCannotTransfer.
//
// On success the move is persisted and recorded in the audit log against the
// target workspace.
//
// NOTE(review): s.notifierDatabaseCounter is injected via
// SetNotifierDatabaseCounter and starts out nil — a call before wiring
// completes would panic here; confirm the wiring order.
func (s *NotifierService) TransferNotifierToWorkspace(
	user *users_models.User,
	notifierID uuid.UUID,
	targetWorkspaceID uuid.UUID,
	transferringWithDbID *uuid.UUID,
) error {
	existingNotifier, err := s.notifierRepository.FindByID(notifierID)
	if err != nil {
		return err
	}

	// Permission check in the notifier's current (source) workspace.
	canManageSource, err := s.workspaceService.CanUserManageDBs(existingNotifier.WorkspaceID, user)
	if err != nil {
		return err
	}
	if !canManageSource {
		return ErrInsufficientPermissionsInSourceWorkspace
	}

	// Permission check in the destination workspace.
	canManageTarget, err := s.workspaceService.CanUserManageDBs(targetWorkspaceID, user)
	if err != nil {
		return err
	}
	if !canManageTarget {
		return ErrInsufficientPermissionsInTargetWorkspace
	}

	attachedDatabasesIDs, err := s.notifierDatabaseCounter.GetNotifierAttachedDatabasesIDs(
		existingNotifier.ID,
	)
	if err != nil {
		return err
	}

	if transferringWithDbID != nil {
		// Transferring together with one database: every attachment must be
		// that database, otherwise the notifier is still in use elsewhere.
		for _, dbID := range attachedDatabasesIDs {
			if dbID != *transferringWithDbID {
				return ErrNotifierHasOtherAttachedDatabasesCannotTransfer
			}
		}
	} else {
		// Standalone transfer: any attachment at all blocks the move.
		if len(attachedDatabasesIDs) > 0 {
			return ErrNotifierHasAttachedDatabasesCannotTransfer
		}
	}

	// Re-home the notifier; remember the source workspace for the audit entry.
	sourceWorkspaceID := existingNotifier.WorkspaceID
	existingNotifier.WorkspaceID = targetWorkspaceID

	_, err = s.notifierRepository.Save(existingNotifier)
	if err != nil {
		return err
	}

	s.auditLogService.WriteAuditLog(
		fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
			existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
		&user.ID,
		&targetWorkspaceID,
	)

	return nil
}
func (s *NotifierService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
if err != nil {

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,10 +16,12 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
@@ -35,18 +38,6 @@ import (
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
return router
}
func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -250,6 +241,124 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
assert.True(t, found, "Audit log for restore not found")
}
// Test_RestoreBackup_DiskSpaceValidation is a table-driven test asserting
// WHEN free-disk-space validation runs during a restore: it is expected for
// PostgreSQL with CpuCount 4, and skipped for PostgreSQL with CpuCount 1 and
// for MySQL. The backup size is inflated to 10 TB so any case that DOES
// validate must fail with a disk-space error.
func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
	tests := []struct {
		name                string
		dbType              databases.DatabaseType
		cpuCount            int
		// Whether the restore endpoint is expected to run (and fail) the
		// disk-space check for this configuration.
		expectDiskValidated bool
	}{
		{
			name:                "PostgreSQL_CPU4_SpaceValidated",
			dbType:              databases.DatabaseTypePostgres,
			cpuCount:            4,
			expectDiskValidated: true,
		},
		{
			name:                "PostgreSQL_CPU1_SpaceNotValidated",
			dbType:              databases.DatabaseTypePostgres,
			cpuCount:            1,
			expectDiskValidated: false,
		},
		{
			name:                "MySQL_SpaceNotValidated",
			dbType:              databases.DatabaseTypeMysql,
			cpuCount:            3,
			expectDiskValidated: false,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			router := createTestRouter()
			owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
			var backup *backups.Backup
			var request RestoreBackupRequest
			if tc.dbType == databases.DatabaseTypePostgres {
				// PostgreSQL path: helper creates the database together with
				// a finished backup.
				_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
				request = RestoreBackupRequest{
					PostgresqlDatabase: &postgresql.PostgresqlDatabase{
						Version:  tools.PostgresqlVersion16,
						Host:     "localhost",
						Port:     5432,
						Username: "postgres",
						Password: "postgres",
						CpuCount: tc.cpuCount,
					},
				}
			} else {
				// MySQL path: the database, storage, and backup config must
				// be wired up manually before a backup can be created.
				mysqlDB := createTestMySQLDatabase("Test MySQL DB", workspace.ID, owner.Token, router)
				storage := createTestStorage(workspace.ID)
				configService := backups_config.GetBackupConfigService()
				config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
				assert.NoError(t, err)
				config.IsBackupsEnabled = true
				config.StorageID = &storage.ID
				config.Storage = storage
				_, err = configService.SaveBackupConfig(config)
				assert.NoError(t, err)
				backup = createTestBackup(mysqlDB, owner)
				request = RestoreBackupRequest{
					MysqlDatabase: &mysql.MysqlDatabase{
						Version:  tools.MysqlVersion80,
						Host:     "localhost",
						Port:     3306,
						Username: "root",
						Password: "password",
					},
				}
			}
			// Set huge backup size (10 TB) that would fail disk validation if checked
			repo := &backups.BackupRepository{}
			backup.BackupSizeMb = 10485760.0
			err := repo.Save(backup)
			assert.NoError(t, err)
			// A validated case must be rejected up front; an unvalidated case
			// proceeds to start the restore.
			expectedStatus := http.StatusOK
			if tc.expectDiskValidated {
				expectedStatus = http.StatusBadRequest
			}
			testResp := test_utils.MakePostRequest(
				t,
				router,
				fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
				"Bearer "+owner.Token,
				request,
				expectedStatus,
			)
			bodyStr := string(testResp.Body)
			if tc.expectDiskValidated {
				// The rejection message should spell out required vs
				// available disk space.
				assert.Contains(t, bodyStr, "is required")
				assert.Contains(t, bodyStr, "is available")
				assert.Contains(t, bodyStr, "disk space")
			} else {
				assert.Contains(t, bodyStr, "restore started successfully")
			}
		})
	}
}
// createTestRouter builds a gin router wired with every controller the
// restore tests exercise: workspaces, memberships, databases, backup
// configuration, backups, and restores.
func createTestRouter() *gin.Engine {
	router := workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
		backups.GetBackupController(),
		GetRestoreController(),
	)
	return router
}
func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
@@ -283,20 +392,11 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(
@@ -321,6 +421,64 @@ func createTestDatabase(
return &database
}
// createTestMySQLDatabase registers a MySQL 8.0 database through the public
// API and returns the created record. The target port is read from
// TEST_MYSQL_80_PORT, falling back to 33080 when unset. Any failure panics,
// since this helper runs inside test setup where an error return is not
// actionable.
func createTestMySQLDatabase(
	name string,
	workspaceID uuid.UUID,
	token string,
	router *gin.Engine,
) *databases.Database {
	rawPort := config.GetEnv().TestMysql80Port
	if rawPort == "" {
		rawPort = "33080"
	}
	mysqlPort, parseErr := strconv.Atoi(rawPort)
	if parseErr != nil {
		panic(fmt.Sprintf("Failed to parse TEST_MYSQL_80_PORT: %v", parseErr))
	}

	dbName := "testdb"
	payload := databases.Database{
		WorkspaceID: &workspaceID,
		Name:        name,
		Type:        databases.DatabaseTypeMysql,
		Mysql: &mysql.MysqlDatabase{
			Version:  tools.MysqlVersion80,
			Host:     "localhost",
			Port:     mysqlPort,
			Username: "testuser",
			Password: "testpassword",
			Database: &dbName,
		},
	}

	resp := workspaces_testing.MakeAPIRequest(
		router,
		"POST",
		"/api/v1/databases/create",
		"Bearer "+token,
		payload,
	)
	if resp.Code != http.StatusCreated {
		panic(
			fmt.Sprintf(
				"Failed to create MySQL database. Status: %d, Body: %s",
				resp.Code,
				resp.Body.String(),
			),
		)
	}

	var created databases.Database
	if unmarshalErr := json.Unmarshal(resp.Body.Bytes(), &created); unmarshalErr != nil {
		panic(unmarshalErr)
	}
	return &created
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,

View File

@@ -5,6 +5,7 @@ import (
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/restores/usecases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -24,6 +25,7 @@ var restoreService = &RestoreService{
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
disk.GetDiskService(),
}
var restoreController = &RestoreController{
restoreService,

View File

@@ -5,6 +5,7 @@ import (
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/disk"
"databasus-backend/internal/features/restores/enums"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/restores/usecases"
@@ -32,6 +33,7 @@ type RestoreService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
diskService *disk.DiskService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
@@ -126,6 +128,11 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
// Validate disk space before starting restore
if err := s.validateDiskSpace(backup, requestDTO); err != nil {
return err
}
go func() {
if err := s.RestoreBackup(backup, requestDTO); err != nil {
s.logger.Error("Failed to restore backup", "error", err)
@@ -222,8 +229,8 @@ func (s *RestoreService) RestoreBackup(
Mongodb: requestDTO.MongodbDatabase,
}
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
isExcludeExtensions := false
@@ -361,3 +368,58 @@ func (s *RestoreService) validateVersionCompatibility(
}
return nil
}
// validateDiskSpace guards file-based PostgreSQL restores against running out
// of local disk. Only pg_restore paths that materialize the backup on disk
// are checked: CPU > 1 (parallel jobs need a file) or extension exclusion
// (TOC filtering needs a file). Other databases and single-CPU PostgreSQL
// restores stream directly from storage and never touch local disk, so no
// validation applies to them.
func (s *RestoreService) validateDiskSpace(
	backup *backups.Backup,
	requestDTO RestoreBackupRequest,
) error {
	pgTarget := requestDTO.PostgresqlDatabase
	if pgTarget == nil {
		return nil
	}
	if pgTarget.CpuCount <= 1 && !pgTarget.IsExcludeExtensions {
		// Streaming restore: nothing is written locally.
		return nil
	}

	diskUsage, err := s.diskService.GetDiskUsage()
	if err != nil {
		return fmt.Errorf("failed to check disk space: %w", err)
	}

	const bytesPerGB = int64(1024 * 1024 * 1024)

	// Required space = backup size (MB -> bytes) plus a 10% safety buffer,
	// floored at 1 GB even for tiny backups.
	backupSizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
	bufferBytes := int64(float64(backupSizeBytes) * 0.1)
	requiredBytes := backupSizeBytes + bufferBytes
	if requiredBytes < bytesPerGB {
		requiredBytes = bytesPerGB
	}

	if diskUsage.FreeSpaceBytes >= requiredBytes {
		return nil
	}

	toGB := func(b int64) float64 { return float64(b) / (1024 * 1024 * 1024) }
	return fmt.Errorf(
		"to restore this backup, %.1f GB (%.1f GB backup + %.1f GB buffer) is required, but only %.1f GB is available. Please free up disk space before restoring",
		toGB(requiredBytes),
		toGB(backupSizeBytes),
		toGB(bufferBytes),
		toGB(diskUsage.FreeSpaceBytes),
	)
}

View File

@@ -27,7 +27,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -134,11 +133,16 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
}
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMariadbRestore(
ctx,
@@ -146,7 +150,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin,
args,
myCnfFile,
tempBackupFile,
rawReader,
backup,
)
}
@@ -157,7 +161,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
mariadbBin string,
args []string,
myCnfFile string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -165,16 +169,10 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
cmd := exec.CommandContext(ctx, mariadbBin, fullArgs...)
uc.logger.Info("Executing MariaDB restore command", "command", cmd.String())
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -225,69 +223,6 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
return nil
}
// downloadBackupToTempFile copies the backup object from remote storage into
// a freshly created temp directory and returns the local file path together
// with a cleanup function that removes that directory. The caller must invoke
// the cleanup function once the restore has finished with the file. The file
// is written as-is (still zstd-compressed and, when enabled, still encrypted);
// decryption/decompression happens later in the restore pipeline.
func (uc *RestoreMariadbBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// Unique directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}

	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		// On failure the temp dir is removed here; the caller only receives
		// cleanupFunc on the success path.
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}
	defer func() {
		if err := tempFile.Close(); err != nil {
			uc.logger.Error("Failed to close temporary file", "error", err)
		}
	}()

	// Copy with periodic context/shutdown checks so a long download can be
	// aborted promptly.
	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -330,7 +265,7 @@ func (uc *RestoreMariadbBackupUsecase) createTempMyCnfFile(
mdbConfig *mariadbtypes.MariadbDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp("", "mycnf")
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -358,57 +293,6 @@ port=%d
return myCnfFile, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks, aborting early if
// the context is cancelled or the process is shutting down. It mirrors the
// short-write protections of io.Copy and returns the number of bytes written.
func (uc *RestoreMariadbBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	// Large buffer keeps the syscall count low for multi-GB backup files.
	buf := make([]byte, 16*1024*1024)
	var totalBytesWritten int64

	for {
		// Cancellation is checked between chunks, not inside Read/Write.
		select {
		case <-ctx.Done():
			return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
		}

		bytesRead, readErr := src.Read(buf)
		if bytesRead > 0 {
			bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
			// Guard against writers reporting an impossible byte count
			// (the same defensive check io.Copy performs).
			if bytesWritten < 0 || bytesRead < bytesWritten {
				bytesWritten = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return totalBytesWritten, writeErr
			}
			if bytesRead != bytesWritten {
				return totalBytesWritten, io.ErrShortWrite
			}
			totalBytesWritten += int64(bytesWritten)
		}
		if readErr != nil {
			// io.EOF marks a clean end of stream; anything else is an error.
			if readErr != io.EOF {
				return totalBytesWritten, readErr
			}
			break
		}
	}
	return totalBytesWritten, nil
}
func (uc *RestoreMariadbBackupUsecase) handleMariadbRestoreError(
database *databases.Database,
waitErr error,

View File

@@ -13,8 +13,6 @@ import (
"strings"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/encryption"
@@ -25,7 +23,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -149,20 +146,26 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
}
}()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMongoRestore(ctx, mongorestoreBin, args, tempBackupFile, backup)
return uc.executeMongoRestore(ctx, mongorestoreBin, args, rawReader, backup)
}
func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
ctx context.Context,
mongorestoreBin string,
args []string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -183,16 +186,10 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
safeArgs,
)
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -232,69 +229,6 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
return nil
}
// downloadBackupToTempFile copies the backup archive from remote storage into
// a freshly created temp directory and returns the local file path together
// with a cleanup function that removes that directory. The caller must invoke
// the cleanup function once the restore has finished with the file. The file
// is written as-is (still gzip-compressed and, when enabled, still encrypted);
// decryption/decompression happens later in the restore pipeline.
func (uc *RestoreMongodbBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// Unique directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}

	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.archive.gz")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		// On failure the temp dir is removed here; the caller only receives
		// cleanupFunc on the success path.
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}
	defer func() {
		if err := tempFile.Close(); err != nil {
			uc.logger.Error("Failed to close temporary file", "error", err)
		}
	}()

	// Copy with periodic context/shutdown checks so a long download can be
	// aborted promptly.
	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -332,57 +266,6 @@ func (uc *RestoreMongodbBackupUsecase) setupDecryption(
return decryptReader, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks, aborting early if
// the context is cancelled or the process is shutting down. It mirrors the
// short-write protections of io.Copy and returns the number of bytes written.
func (uc *RestoreMongodbBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	// Large buffer keeps the syscall count low for multi-GB backup files.
	buf := make([]byte, 16*1024*1024)
	var totalBytesWritten int64

	for {
		// Cancellation is checked between chunks, not inside Read/Write.
		select {
		case <-ctx.Done():
			return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
		}

		bytesRead, readErr := src.Read(buf)
		if bytesRead > 0 {
			bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
			// Guard against writers reporting an impossible byte count
			// (the same defensive check io.Copy performs).
			if bytesWritten < 0 || bytesRead < bytesWritten {
				bytesWritten = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return totalBytesWritten, writeErr
			}
			if bytesRead != bytesWritten {
				return totalBytesWritten, io.ErrShortWrite
			}
			totalBytesWritten += int64(bytesWritten)
		}
		if readErr != nil {
			// io.EOF marks a clean end of stream; anything else is an error.
			if readErr != io.EOF {
				return totalBytesWritten, readErr
			}
			break
		}
	}
	return totalBytesWritten, nil
}
func (uc *RestoreMongodbBackupUsecase) handleMongoRestoreError(
waitErr error,
stderrOutput []byte,

View File

@@ -27,7 +27,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -134,13 +133,18 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
}
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMysqlRestore(ctx, database, mysqlBin, args, myCnfFile, tempBackupFile, backup)
return uc.executeMysqlRestore(ctx, database, mysqlBin, args, myCnfFile, rawReader, backup)
}
func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
@@ -149,7 +153,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
mysqlBin string,
args []string,
myCnfFile string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -157,16 +161,10 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
cmd := exec.CommandContext(ctx, mysqlBin, fullArgs...)
uc.logger.Info("Executing MySQL restore command", "command", cmd.String())
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -217,69 +215,6 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
return nil
}
// downloadBackupToTempFile copies the backup object from remote storage into
// a freshly created temp directory and returns the local file path together
// with a cleanup function that removes that directory. The caller must invoke
// the cleanup function once the restore has finished with the file. The file
// is written as-is (still zstd-compressed and, when enabled, still encrypted);
// decryption/decompression happens later in the restore pipeline.
func (uc *RestoreMysqlBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// Unique directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}

	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		// On failure the temp dir is removed here; the caller only receives
		// cleanupFunc on the success path.
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}
	defer func() {
		if err := tempFile.Close(); err != nil {
			uc.logger.Error("Failed to close temporary file", "error", err)
		}
	}()

	// Copy with periodic context/shutdown checks so a long download can be
	// aborted promptly.
	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -322,7 +257,7 @@ func (uc *RestoreMysqlBackupUsecase) createTempMyCnfFile(
myConfig *mysqltypes.MysqlDatabase,
password string,
) (string, error) {
tempDir, err := os.MkdirTemp("", "mycnf")
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "mycnf_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temp directory: %w", err)
}
@@ -348,57 +283,6 @@ port=%d
return myCnfFile, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks, aborting early if
// the context is cancelled or the process is shutting down. It mirrors the
// short-write protections of io.Copy and returns the number of bytes written.
func (uc *RestoreMysqlBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	// Large buffer keeps the syscall count low for multi-GB backup files.
	buf := make([]byte, 16*1024*1024)
	var totalBytesWritten int64

	for {
		// Cancellation is checked between chunks, not inside Read/Write.
		select {
		case <-ctx.Done():
			return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
		}

		bytesRead, readErr := src.Read(buf)
		if bytesRead > 0 {
			bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
			// Guard against writers reporting an impossible byte count
			// (the same defensive check io.Copy performs).
			if bytesWritten < 0 || bytesRead < bytesWritten {
				bytesWritten = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return totalBytesWritten, writeErr
			}
			if bytesRead != bytesWritten {
				return totalBytesWritten, io.ErrShortWrite
			}
			totalBytesWritten += int64(bytesWritten)
		}
		if readErr != nil {
			// io.EOF marks a clean end of stream; anything else is an error.
			if readErr != io.EOF {
				return totalBytesWritten, readErr
			}
			break
		}
	}
	return totalBytesWritten, nil
}
func (uc *RestoreMysqlBackupUsecase) handleMysqlRestoreError(
database *databases.Database,
waitErr error,

View File

@@ -24,7 +24,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
"github.com/google/uuid"
@@ -65,33 +64,298 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
return fmt.Errorf("target database name is required for pg_restore")
}
// Use parallel jobs based on CPU count (same as backup)
// Cap between 1 and 8 to avoid overwhelming the server
parallelJobs := max(1, min(restoringToDB.Postgresql.CpuCount, 8))
pgBin := tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
config.GetEnv().EnvMode,
config.GetEnv().PostgresesInstallDir,
)
// All PostgreSQL backups are now custom format (-Fc)
return uc.restoreCustomType(
originalDB,
pgBin,
backup,
storage,
pg,
isExcludeExtensions,
)
}
// restoreCustomType dispatches a custom-format (-Fc) pg_restore to the
// appropriate strategy: file-based when TOC filtering or parallel jobs are
// needed, stdin streaming otherwise.
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
	originalDB *databases.Database,
	pgBin string,
	backup *backups.Backup,
	storage *storages.Storage,
	pg *pgtypes.PostgresqlDatabase,
	isExcludeExtensions bool,
) error {
	uc.logger.Info(
		"Restoring backup in custom type (-Fc)",
		"backupId",
		backup.ID,
		"cpuCount",
		pg.CpuCount,
	)

	// Extension exclusion requires rewriting the TOC, and parallel jobs
	// require random access — both force a local file copy of the backup.
	needsFile := isExcludeExtensions || pg.CpuCount > 1
	if !needsFile {
		// Single CPU without extension exclusion: stream directly via stdin.
		return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
	}
	return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
}
// restoreViaStdin streams backup via stdin for single CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
) error {
uc.logger.Info("Restoring via stdin streaming (CPU=1)", "backupId", backup.ID)
args := []string{
"-Fc", // expect custom format (same as backup)
"-j", strconv.Itoa(parallelJobs), // parallel jobs based on CPU count
"--no-password", // Use environment variable for password, prevent prompts
"-Fc", // expect custom type
"--no-password",
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose", // Add verbose output to help with debugging
"--clean", // Clean (drop) database objects before recreating them
"--if-exists", // Use IF EXISTS when dropping objects
"--no-owner", // Skip restoring ownership
"--no-acl", // Skip restoring access privileges (GRANT/REVOKE commands)
"--verbose",
"--clean",
"--if-exists",
"--no-owner",
"--no-acl",
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
defer cancel()
// Monitor for shutdown and cancel context if needed
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
// Create temporary .pgpass file for authentication
fieldEncryptor := util_encryption.GetFieldEncryptor()
decryptedPassword, err := fieldEncryptor.Decrypt(originalDB.ID, pg.Password)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
pgpassFile, err := uc.createTempPgpassFile(pg, decryptedPassword)
if err != nil {
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
}
defer func() {
if pgpassFile != "" {
_ = os.RemoveAll(filepath.Dir(pgpassFile))
}
}()
// Verify .pgpass file was created successfully
if pgpassFile == "" {
return fmt.Errorf("temporary .pgpass file was not created")
}
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return fmt.Errorf("failed to verify .pgpass file: %w", err)
}
// Get backup stream from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
var backupReader io.Reader = rawReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
// Validate encryption metadata
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
return fmt.Errorf("backup is encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return fmt.Errorf("failed to get master key for decryption: %w", err)
}
// Decode salt and IV from base64
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
return fmt.Errorf("failed to decode encryption salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
return fmt.Errorf("failed to decode encryption IV: %w", err)
}
// Create decryption reader
decryptReader, err := encryption.NewDecryptionReader(
rawReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
return fmt.Errorf("failed to create decryption reader: %w", err)
}
backupReader = decryptReader
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
}
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL restore command via stdin", "command", cmd.String())
// Setup environment variables
uc.setupPgRestoreEnvironment(cmd, pgpassFile, pg)
// Verify executable exists and is accessible
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf(
"PostgreSQL executable not found or not accessible: %s - %w",
pgBin,
err,
)
}
// Create stdin pipe for explicit data pumping
stdinPipe, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("stdin pipe: %w", err)
}
// Get stderr to capture any error output
pgStderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("stderr pipe: %w", err)
}
// Capture stderr in a separate goroutine
stderrCh := make(chan []byte, 1)
go func() {
stderrOutput, _ := io.ReadAll(pgStderr)
stderrCh <- stderrOutput
}()
// Start pg_restore
if err = cmd.Start(); err != nil {
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
}
// Copy backup data to stdin in a separate goroutine with proper error handling
copyErrCh := make(chan error, 1)
go func() {
_, copyErr := io.Copy(stdinPipe, backupReader)
// Close stdin pipe to signal EOF to pg_restore - critical for proper termination
closeErr := stdinPipe.Close()
if copyErr != nil {
copyErrCh <- fmt.Errorf("copy to stdin: %w", copyErr)
} else if closeErr != nil {
copyErrCh <- fmt.Errorf("close stdin: %w", closeErr)
} else {
copyErrCh <- nil
}
}()
// Wait for the restore to finish
waitErr := cmd.Wait()
stderrOutput := <-stderrCh
copyErr := <-copyErrCh
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
// Check for copy errors first - these indicate issues with decryption or data reading
if copyErr != nil {
return fmt.Errorf("failed to stream backup data to pg_restore: %w", copyErr)
}
if waitErr != nil {
if config.IsShouldShutdown() {
return fmt.Errorf("restore cancelled due to shutdown")
}
return uc.handlePgRestoreError(originalDB, waitErr, stderrOutput, pgBin, args, pg)
}
return nil
}
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
) error {
uc.logger.Info(
"Restoring via file with parallel jobs",
"backupId",
backup.ID,
"cpuCount",
pg.CpuCount,
)
// Use parallel jobs based on CPU count
// Cap between 1 and 8 to avoid overwhelming the server
parallelJobs := max(1, min(pg.CpuCount, 8))
args := []string{
"-Fc", // expect custom type
"-j", strconv.Itoa(parallelJobs), // parallel jobs based on CPU count
"--no-password",
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose",
"--clean",
"--if-exists",
"--no-owner",
"--no-acl",
}
return uc.restoreFromStorage(
originalDB,
tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
config.GetEnv().EnvMode,
config.GetEnv().PostgresesInstallDir,
),
pgBin,
args,
pg.Password,
backup,
@@ -150,7 +414,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
}
defer func() {
if pgpassFile != "" {
_ = os.Remove(pgpassFile)
_ = os.RemoveAll(filepath.Dir(pgpassFile))
}
}()
@@ -208,13 +472,6 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
backup *backups.Backup,
storage *storages.Storage,
) (string, func(), error) {
err := files_utils.EnsureDirectories([]string{
config.GetEnv().TempFolder,
})
if err != nil {
return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
}
// Create temporary directory for backup data
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
if err != nil {
@@ -621,7 +878,7 @@ func (uc *RestorePostgresqlBackupUsecase) generateFilteredTocList(
}
// Write filtered TOC to temporary file
tocFile, err := os.CreateTemp("", "pg_restore_toc_*.list")
tocFile, err := os.CreateTemp(config.GetEnv().TempFolder, "pg_restore_toc_*.list")
if err != nil {
return "", fmt.Errorf("failed to create TOC list file: %w", err)
}
@@ -668,7 +925,7 @@ func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
escapedPassword,
)
tempDir, err := os.MkdirTemp("", "pgpass")
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "pgpass_"+uuid.New().String())
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
}

View File

@@ -1,6 +1,8 @@
package storages
import (
"errors"
users_middleware "databasus-backend/internal/features/users/middleware"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"net/http"
@@ -20,6 +22,7 @@ func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/storages/:id", c.GetStorage)
router.DELETE("/storages/:id", c.DeleteStorage)
router.POST("/storages/:id/test", c.TestStorageConnection)
router.POST("/storages/:id/transfer", c.TransferStorageToWorkspace)
router.POST("/storages/direct-test", c.TestStorageConnectionDirect)
}
@@ -55,7 +58,7 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
}
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -93,7 +96,7 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
storage, err := c.storageService.GetStorage(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view storage in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToViewStorage) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -137,7 +140,7 @@ func (c *StorageController) GetStorages(ctx *gin.Context) {
storages, err := c.storageService.GetStorages(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view storages in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToViewStorages) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -174,7 +177,7 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
}
if err := c.storageService.DeleteStorage(user, id); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToManageStorage) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -211,7 +214,7 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
}
if err := c.storageService.TestStorageConnection(user, id); err != nil {
if err.Error() == "insufficient permissions to test storage in this workspace" {
if errors.Is(err, ErrInsufficientPermissionsToTestStorage) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -222,6 +225,57 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "storage connection test successful"})
}
// TransferStorageToWorkspace
// @Summary Transfer storage to another workspace
// @Description Transfer a storage from one workspace to another
// @Tags storages
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param id path string true "Storage ID"
// @Param request body TransferStorageRequest true "Target workspace ID"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id}/transfer [post]
func (c *StorageController) TransferStorageToWorkspace(ctx *gin.Context) {
	// Only authenticated users may trigger a transfer.
	currentUser, authenticated := users_middleware.GetUserFromContext(ctx)
	if !authenticated {
		ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
		return
	}

	// The storage to transfer is addressed by its path ID.
	storageID, parseErr := uuid.Parse(ctx.Param("id"))
	if parseErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
		return
	}

	// Parse and validate the request body; a zero UUID is treated as missing.
	var payload TransferStorageRequest
	if bindErr := ctx.ShouldBindJSON(&payload); bindErr != nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}
	if payload.TargetWorkspaceID == uuid.Nil {
		ctx.JSON(http.StatusBadRequest, gin.H{"error": "targetWorkspaceId is required"})
		return
	}

	// Delegate to the service; permission failures map to 403, everything
	// else to 400. The nil argument means "not transferring with a database".
	transferErr := c.storageService.TransferStorageToWorkspace(
		currentUser, storageID, payload.TargetWorkspaceID, nil,
	)
	if transferErr != nil {
		if errors.Is(transferErr, ErrInsufficientPermissionsInSourceWorkspace) ||
			errors.Is(transferErr, ErrInsufficientPermissionsInTargetWorkspace) {
			ctx.JSON(http.StatusForbidden, gin.H{"error": transferErr.Error()})
			return
		}
		ctx.JSON(http.StatusBadRequest, gin.H{"error": transferErr.Error()})
		return
	}

	ctx.JSON(http.StatusOK, gin.H{"message": "storage transferred successfully"})
}
// TestStorageConnectionDirect
// @Summary Test storage connection directly
// @Description Test the connection to a storage object provided in the request

View File

@@ -29,6 +29,14 @@ import (
"github.com/stretchr/testify/assert"
)
// mockStorageDatabaseCounter is a test double for StorageDatabaseCounter that
// always reports a storage as having no attached databases.
type mockStorageDatabaseCounter struct{}

// GetStorageAttachedDatabasesIDs returns an empty (non-nil) ID slice and no
// error, so delete/transfer checks in the service always pass in tests.
func (m *mockStorageDatabaseCounter) GetStorageAttachedDatabasesIDs(
	storageID uuid.UUID,
) ([]uuid.UUID, error) {
	noAttached := make([]uuid.UUID, 0)
	return noAttached, nil
}
func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
@@ -200,161 +208,161 @@ func Test_TestExistingStorageConnection_ConnectionEstablished(t *testing.T) {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_ViewerCanViewStorages_ButCannotModify(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
viewer,
users_enums.WorkspaceRoleViewer,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
func Test_WorkspaceRolePermissions(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
canCreate bool
canUpdate bool
canDelete bool
}{
{
name: "owner can manage storages",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "admin can manage storages",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "member can manage storages",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
canCreate: true,
canUpdate: true,
canDelete: true,
},
{
name: "viewer can view but cannot modify storages",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
canCreate: false,
canUpdate: false,
canDelete: false,
},
{
name: "global admin can manage storages",
workspaceRole: nil,
isGlobalAdmin: true,
canCreate: true,
canUpdate: true,
canDelete: true,
},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createRouter()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
// Viewer can GET storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+viewer.Token,
http.StatusOK,
&storages,
)
assert.Len(t, storages, 1)
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Viewer cannot CREATE storage
newStorage := createNewStorage(workspace.ID)
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, *newStorage, http.StatusForbidden,
)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(
workspace,
testUser,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = testUser.Token
}
// Viewer cannot UPDATE storage
savedStorage.Name = "Updated by viewer"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, savedStorage, http.StatusForbidden,
)
// Owner creates initial storage for all test cases
var ownerStorage Storage
storage := createNewStorage(workspace.ID)
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", "Bearer "+owner.Token,
*storage, http.StatusOK, &ownerStorage,
)
// Viewer cannot DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+viewer.Token,
http.StatusForbidden,
)
// Test GET storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+testUserToken, http.StatusOK, &storages,
)
assert.Len(t, storages, 1)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
// Test CREATE storage
createStatusCode := http.StatusOK
if !tt.canCreate {
createStatusCode = http.StatusForbidden
}
newStorage := createNewStorage(workspace.ID)
var savedStorage Storage
if tt.canCreate {
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", "Bearer "+testUserToken,
*newStorage, createStatusCode, &savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
} else {
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+testUserToken,
*newStorage, createStatusCode,
)
}
func Test_MemberCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Test UPDATE storage
updateStatusCode := http.StatusOK
if !tt.canUpdate {
updateStatusCode = http.StatusForbidden
}
ownerStorage.Name = "Updated by test user"
if tt.canUpdate {
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", "Bearer "+testUserToken,
ownerStorage, updateStatusCode, &updatedStorage,
)
assert.Equal(t, "Updated by test user", updatedStorage.Name)
} else {
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+testUserToken,
ownerStorage, updateStatusCode,
)
}
// Member can CREATE storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
*storage,
http.StatusOK,
&savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
// Test DELETE storage
deleteStatusCode := http.StatusOK
if !tt.canDelete {
deleteStatusCode = http.StatusForbidden
}
test_utils.MakeDeleteRequest(
t, router,
fmt.Sprintf("/api/v1/storages/%s", ownerStorage.ID.String()),
"Bearer "+testUserToken, deleteStatusCode,
)
// Member can UPDATE storage
savedStorage.Name = "Updated by member"
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
assert.Equal(t, "Updated by member", updatedStorage.Name)
// Member can DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_AdminCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
admin,
users_enums.WorkspaceRoleAdmin,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Admin can CREATE, UPDATE, DELETE
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*storage,
http.StatusOK,
&savedStorage,
)
savedStorage.Name = "Updated by admin"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+admin.Token, savedStorage, http.StatusOK,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Cleanup
if tt.canCreate {
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
}
if !tt.canDelete {
deleteStorage(t, router, ownerStorage.ID, workspace.ID, owner.Token)
}
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
@@ -975,6 +983,184 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
}
// Test_TransferStorage_PermissionsEnforced table-tests the transfer endpoint:
// a user needs manage rights in BOTH the source and the target workspace
// (owner/admin/member role, or being a global admin) to transfer a storage;
// a viewer in both workspaces is rejected with 403.
func Test_TransferStorage_PermissionsEnforced(t *testing.T) {
	tests := []struct {
		name               string
		sourceRole         *users_enums.WorkspaceRole // role granted in the source workspace (nil = none)
		targetRole         *users_enums.WorkspaceRole // role granted in the target workspace (nil = none)
		isGlobalAdmin      bool
		expectSuccess      bool
		expectedStatusCode int
	}{
		{
			name:               "owner in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "admin in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "member in both workspaces can transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
		{
			name:               "viewer in both workspaces cannot transfer",
			sourceRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
			targetRole:         func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
			isGlobalAdmin:      false,
			expectSuccess:      false,
			expectedStatusCode: http.StatusForbidden,
		},
		{
			name:               "global admin can transfer",
			sourceRole:         nil,
			targetRole:         nil,
			isGlobalAdmin:      true,
			expectSuccess:      true,
			expectedStatusCode: http.StatusOK,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			router := createRouter()
			// Stub out the databases dependency so the storage is seen as
			// having no attached databases (transfers are otherwise blocked).
			GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
			sourceOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			targetOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
			sourceWorkspace := workspaces_testing.CreateTestWorkspace(
				"Source Workspace",
				sourceOwner,
				router,
			)
			targetWorkspace := workspaces_testing.CreateTestWorkspace(
				"Target Workspace",
				targetOwner,
				router,
			)
			// The storage under test starts out in the source workspace.
			storage := createNewStorage(sourceWorkspace.ID)
			var savedStorage Storage
			test_utils.MakePostRequestAndUnmarshal(
				t,
				router,
				"/api/v1/storages",
				"Bearer "+sourceOwner.Token,
				*storage,
				http.StatusOK,
				&savedStorage,
			)
			// Build the acting user for this case: either a global admin, or
			// a fresh user added to both workspaces with the table's roles.
			var testUserToken string
			if tt.isGlobalAdmin {
				admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
				testUserToken = admin.Token
			} else if tt.sourceRole != nil {
				testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
				workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
				workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
				testUserToken = testUser.Token
			}
			request := TransferStorageRequest{
				TargetWorkspaceID: targetWorkspace.ID,
			}
			testResp := test_utils.MakePostRequest(
				t,
				router,
				fmt.Sprintf("/api/v1/storages/%s/transfer", savedStorage.ID.String()),
				"Bearer "+testUserToken,
				request,
				tt.expectedStatusCode,
			)
			if tt.expectSuccess {
				assert.Contains(t, string(testResp.Body), "transferred successfully")
				// After a successful transfer the storage must be readable by
				// the target workspace's owner and carry the new workspace ID.
				var retrievedStorage Storage
				test_utils.MakeGetRequestAndUnmarshal(
					t,
					router,
					fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
					"Bearer "+targetOwner.Token,
					http.StatusOK,
					&retrievedStorage,
				)
				assert.Equal(t, targetWorkspace.ID, retrievedStorage.WorkspaceID)
				deleteStorage(t, router, savedStorage.ID, targetWorkspace.ID, targetOwner.Token)
			} else {
				assert.Contains(t, string(testResp.Body), "insufficient permissions")
				// Transfer denied: storage is still in the source workspace.
				deleteStorage(t, router, savedStorage.ID, sourceWorkspace.ID, sourceOwner.Token)
			}
			workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
			workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
		})
	}
}
// Test_TransferStorageNotManagableWorkspace_TransferFailed verifies that a
// user who can manage the source workspace but has no role in the target
// workspace gets 403 with the target-workspace permission error, and that the
// storage stays in the source workspace.
func Test_TransferStorageNotManagableWorkspace_TransferFailed(t *testing.T) {
	router := createRouter()
	// Stub the databases dependency so attached-database checks pass.
	GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
	userA := users_testing.CreateTestUser(users_enums.UserRoleMember)
	userB := users_testing.CreateTestUser(users_enums.UserRoleMember)
	// userA owns workspace1 only; workspace2 belongs to userB.
	workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", userA, router)
	workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", userB, router)
	storage := createNewStorage(workspace1.ID)
	var savedStorage Storage
	test_utils.MakePostRequestAndUnmarshal(
		t,
		router,
		"/api/v1/storages",
		"Bearer "+userA.Token,
		*storage,
		http.StatusOK,
		&savedStorage,
	)
	request := TransferStorageRequest{
		TargetWorkspaceID: workspace2.ID,
	}
	// userA attempts to push the storage into a workspace they cannot manage.
	testResp := test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/storages/%s/transfer", savedStorage.ID.String()),
		"Bearer "+userA.Token,
		request,
		http.StatusForbidden,
	)
	assert.Contains(
		t,
		string(testResp.Body),
		"insufficient permissions to manage storage in target workspace",
	)
	// Cleanup: the storage is still owned by workspace1.
	deleteStorage(t, router, savedStorage.ID, workspace1.ID, userA.Token)
	workspaces_testing.RemoveTestWorkspace(workspace1, router)
	workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -989,6 +1175,7 @@ func createRouter() *gin.Engine {
}
audit_logs.SetupDependencies()
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
return router
}

View File

@@ -12,6 +12,7 @@ var storageService = &StorageService{
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
nil,
}
var storageController = &StorageController{
storageService,

View File

@@ -0,0 +1,7 @@
package storages
import "github.com/google/uuid"
// TransferStorageRequest is the JSON body of POST /storages/{id}/transfer.
type TransferStorageRequest struct {
	// TargetWorkspaceID is the workspace that should own the storage after
	// the transfer. Required; a zero UUID is rejected by the controller.
	TargetWorkspaceID uuid.UUID `json:"targetWorkspaceId" binding:"required"`
}

View File

@@ -0,0 +1,36 @@
package storages
import "errors"
// Sentinel errors for the storages feature. Controllers match them with
// errors.Is to choose HTTP status codes (permission errors map to 403,
// the rest to 400), so the messages double as API responses.
var (
	// Returned when the user cannot create/update/delete storages in the workspace.
	ErrInsufficientPermissionsToManageStorage = errors.New(
		"insufficient permissions to manage storage in this workspace",
	)
	// Returned when the user cannot view a single storage.
	ErrInsufficientPermissionsToViewStorage = errors.New(
		"insufficient permissions to view storage in this workspace",
	)
	// Returned when the user cannot list storages of a workspace.
	ErrInsufficientPermissionsToViewStorages = errors.New(
		"insufficient permissions to view storages in this workspace",
	)
	// Returned when the user cannot run a connection test on a storage.
	ErrInsufficientPermissionsToTestStorage = errors.New(
		"insufficient permissions to test storage in this workspace",
	)
	// Transfer: user lacks manage rights in the storage's current workspace.
	ErrInsufficientPermissionsInSourceWorkspace = errors.New(
		"insufficient permissions to manage storage in source workspace",
	)
	// Transfer: user lacks manage rights in the destination workspace.
	ErrInsufficientPermissionsInTargetWorkspace = errors.New(
		"insufficient permissions to manage storage in target workspace",
	)
	// The storage's workspace ID does not match the one in the request.
	ErrStorageDoesNotBelongToWorkspace = errors.New(
		"storage does not belong to this workspace",
	)
	// Deletion is blocked while any database still uses the storage.
	ErrStorageHasAttachedDatabases = errors.New(
		"storage has attached databases and cannot be deleted",
	)
	// Plain transfer is blocked while any database still uses the storage.
	ErrStorageHasAttachedDatabasesCannotTransfer = errors.New(
		"storage has attached databases and cannot be transferred",
	)
	// Transfer-with-database is blocked when databases other than the one
	// being transferred are attached to the storage.
	ErrStorageHasOtherAttachedDatabasesCannotTransfer = errors.New(
		"storage has other attached databases and cannot be transferred",
	)
)

View File

@@ -30,3 +30,7 @@ type StorageFileSaver interface {
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}
// StorageDatabaseCounter reports which databases reference a storage. It is
// injected into StorageService (see SetStorageDatabaseCounter) so the storages
// feature does not depend on the databases feature directly.
type StorageDatabaseCounter interface {
	// GetStorageAttachedDatabasesIDs returns the IDs of all databases
	// attached to the storage with the given ID.
	GetStorageAttachedDatabasesIDs(storageID uuid.UUID) ([]uuid.UUID, error)
}

View File

@@ -26,6 +26,7 @@ const (
azureResponseTimeout = 30 * time.Second
azureIdleConnTimeout = 90 * time.Second
azureTLSHandshakeTimeout = 30 * time.Second
azureDeleteTimeout = 30 * time.Second
// Chunk size for block blob uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -186,8 +187,11 @@ func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileI
blobName := s.buildBlobName(fileID.String())
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
defer cancel()
_, err = client.DeleteBlob(
context.TODO(),
ctx,
s.ContainerName,
blobName,
nil,

View File

@@ -18,6 +18,7 @@ import (
const (
ftpConnectTimeout = 30 * time.Second
ftpTestConnectTimeout = 10 * time.Second
ftpDeleteTimeout = 30 * time.Second
ftpChunkSize = 16 * 1024 * 1024
)
@@ -134,7 +135,10 @@ func (f *FTPStorage) GetFile(
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
conn, err := f.connect(encryptor, ftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
defer cancel()
conn, err := f.connectWithContext(ctx, encryptor, ftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to FTP: %w", err)
}

View File

@@ -27,6 +27,7 @@ const (
gdResponseTimeout = 30 * time.Second
gdIdleConnTimeout = 90 * time.Second
gdTLSHandshakeTimeout = 30 * time.Second
gdDeleteTimeout = 30 * time.Second
// Chunk size for Google Drive resumable uploads - 16MB provides good balance
// between memory usage and upload efficiency. Google Drive requires chunks
@@ -185,7 +186,9 @@ func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
defer cancel()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {

View File

@@ -18,6 +18,8 @@ import (
)
const (
nasDeleteTimeout = 30 * time.Second
// Chunk size for NAS uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for NAS to confirm receipt.
@@ -193,7 +195,10 @@ func (n *NASStorage) GetFile(
}
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
defer cancel()
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -211,10 +216,8 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
filePath := n.getFilePath(fileID.String())
// Check if file exists before trying to delete
_, err = fs.Stat(filePath)
if err != nil {
// File doesn't exist, consider it already deleted
return nil
}

View File

@@ -22,6 +22,7 @@ import (
const (
rcloneOperationTimeout = 30 * time.Second
rcloneDeleteTimeout = 30 * time.Second
)
var rcloneConfigMu sync.Mutex
@@ -115,7 +116,8 @@ func (r *RcloneStorage) GetFile(
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
defer cancel()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {

View File

@@ -26,6 +26,7 @@ const (
s3ResponseTimeout = 30 * time.Second
s3IdleConnTimeout = 90 * time.Second
s3TLSHandshakeTimeout = 30 * time.Second
s3DeleteTimeout = 30 * time.Second
// Chunk size for multipart uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -228,9 +229,11 @@ func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.
objectKey := s.buildObjectKey(fileID.String())
// Delete the object using MinIO client
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
defer cancel()
err = client.RemoveObject(
context.TODO(),
ctx,
s.S3Bucket,
objectKey,
minio.RemoveObjectOptions{},
@@ -356,7 +359,7 @@ func (s *S3Storage) Update(incoming *S3Storage) {
}
// we do not allow to change the prefix after creation,
// otherwise we will have to migrate all the data to the new prefix
// otherwise we will have to transfer all the data to the new prefix
}
func (s *S3Storage) buildObjectKey(fileName string) string {

View File

@@ -19,6 +19,7 @@ import (
const (
sftpConnectTimeout = 30 * time.Second
sftpTestConnectTimeout = 10 * time.Second
sftpDeleteTimeout = 30 * time.Second
)
type SFTPStorage struct {
@@ -154,7 +155,10 @@ func (s *SFTPStorage) GetFile(
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
defer cancel()
client, sshConn, err := s.connectWithContext(ctx, encryptor, sftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to SFTP: %w", err)
}

View File

@@ -1,7 +1,6 @@
package storages
import (
"errors"
"fmt"
audit_logs "databasus-backend/internal/features/audit_logs"
@@ -13,10 +12,15 @@ import (
)
type StorageService struct {
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
storageDatabaseCounter StorageDatabaseCounter
}
// SetStorageDatabaseCounter injects the dependency used to look up which
// databases are attached to a storage before deleting or transferring it.
func (s *StorageService) SetStorageDatabaseCounter(counter StorageDatabaseCounter) {
	s.storageDatabaseCounter = counter
}
func (s *StorageService) SaveStorage(
@@ -29,7 +33,7 @@ func (s *StorageService) SaveStorage(
return err
}
if !canManage {
return errors.New("insufficient permissions to manage storage in this workspace")
return ErrInsufficientPermissionsToManageStorage
}
isUpdate := storage.ID != uuid.Nil
@@ -41,7 +45,7 @@ func (s *StorageService) SaveStorage(
}
if existingStorage.WorkspaceID != workspaceID {
return errors.New("storage does not belong to this workspace")
return ErrStorageDoesNotBelongToWorkspace
}
existingStorage.Update(storage)
@@ -104,7 +108,15 @@ func (s *StorageService) DeleteStorage(
return err
}
if !canManage {
return errors.New("insufficient permissions to manage storage in this workspace")
return ErrInsufficientPermissionsToManageStorage
}
attachedDatabasesIDs, err := s.storageDatabaseCounter.GetStorageAttachedDatabasesIDs(storage.ID)
if err != nil {
return err
}
if len(attachedDatabasesIDs) > 0 {
return ErrStorageHasAttachedDatabases
}
err = s.storageRepository.Delete(storage)
@@ -135,7 +147,7 @@ func (s *StorageService) GetStorage(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view storage in this workspace")
return nil, ErrInsufficientPermissionsToViewStorage
}
storage.HideSensitiveData()
@@ -152,7 +164,7 @@ func (s *StorageService) GetStorages(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view storages in this workspace")
return nil, ErrInsufficientPermissionsToViewStorages
}
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
@@ -181,7 +193,7 @@ func (s *StorageService) TestStorageConnection(
return err
}
if !canView {
return errors.New("insufficient permissions to test storage in this workspace")
return ErrInsufficientPermissionsToTestStorage
}
err = storage.TestConnection(s.fieldEncryptor)
@@ -212,7 +224,7 @@ func (s *StorageService) TestStorageConnectionDirect(
}
if existingStorage.WorkspaceID != storage.WorkspaceID {
return errors.New("storage does not belong to this workspace")
return ErrStorageDoesNotBelongToWorkspace
}
existingStorage.Update(storage)
@@ -235,6 +247,70 @@ func (s *StorageService) GetStorageByID(
return s.storageRepository.FindByID(id)
}
// TransferStorageToWorkspace moves a storage into targetWorkspaceID on behalf
// of user. The user must be allowed to manage databases in BOTH the source
// and the target workspace. A storage with attached databases cannot be
// moved, except when transferingWithDbID is set: then the only attached
// database allowed is that one (the storage moves together with it).
// Writes an audit log entry on success.
func (s *StorageService) TransferStorageToWorkspace(
	user *users_models.User,
	storageID uuid.UUID,
	targetWorkspaceID uuid.UUID,
	transferingWithDbID *uuid.UUID,
) error {
	storage, err := s.storageRepository.FindByID(storageID)
	if err != nil {
		return err
	}

	// Manage rights are required in the workspace the storage leaves...
	allowedInSource, err := s.workspaceService.CanUserManageDBs(storage.WorkspaceID, user)
	if err != nil {
		return err
	}
	if !allowedInSource {
		return ErrInsufficientPermissionsInSourceWorkspace
	}

	// ...and in the workspace it moves into.
	allowedInTarget, err := s.workspaceService.CanUserManageDBs(targetWorkspaceID, user)
	if err != nil {
		return err
	}
	if !allowedInTarget {
		return ErrInsufficientPermissionsInTargetWorkspace
	}

	attachedIDs, err := s.storageDatabaseCounter.GetStorageAttachedDatabasesIDs(
		storage.ID,
	)
	if err != nil {
		return err
	}
	if transferingWithDbID == nil {
		// A standalone transfer is only allowed when nothing uses the storage.
		if len(attachedIDs) > 0 {
			return ErrStorageHasAttachedDatabasesCannotTransfer
		}
	} else {
		// Moving together with a database: every attached database must be
		// exactly the one being transferred.
		for _, attachedID := range attachedIDs {
			if attachedID != *transferingWithDbID {
				return ErrStorageHasOtherAttachedDatabasesCannotTransfer
			}
		}
	}

	previousWorkspaceID := storage.WorkspaceID
	storage.WorkspaceID = targetWorkspaceID
	if _, err = s.storageRepository.Save(storage); err != nil {
		return err
	}

	s.auditLogService.WriteAuditLog(
		fmt.Sprintf("Storage transferred: %s from workspace %s to workspace %s",
			storage.Name, previousWorkspaceID, targetWorkspaceID),
		&user.ID,
		&targetWorkspaceID,
	)

	return nil
}
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
if err != nil {

View File

@@ -32,21 +32,23 @@ import (
test_utils "databasus-backend/internal/util/testing"
)
const createAndFillTableQuery = `
DROP TABLE IF EXISTS test_data;
func createAndFillTableQuery(tableName string) string {
return fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
CREATE TABLE test_data (
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO test_data (name, value) VALUES
INSERT INTO %s (name, value) VALUES
('test1', 100),
('test2', 200),
('test3', 300);
`
`, tableName, tableName, tableName)
}
type PostgresContainer struct {
Host string
@@ -68,23 +70,31 @@ type TestDataItem struct {
func Test_BackupAndRestorePostgresql_RestoreIsSuccesful(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
name string
version string
port string
cpuCount int
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
{"PostgreSQL 18", "18", env.TestPostgres18Port},
{"PostgreSQL 12 (CPU=1 streamed)", "12", env.TestPostgres12Port, 1},
{"PostgreSQL 12 (CPU=4 directory)", "12", env.TestPostgres12Port, 4},
{"PostgreSQL 13 (CPU=1 streamed)", "13", env.TestPostgres13Port, 1},
{"PostgreSQL 13 (CPU=4 directory)", "13", env.TestPostgres13Port, 4},
{"PostgreSQL 14 (CPU=1 streamed)", "14", env.TestPostgres14Port, 1},
{"PostgreSQL 14 (CPU=4 directory)", "14", env.TestPostgres14Port, 4},
{"PostgreSQL 15 (CPU=1 streamed)", "15", env.TestPostgres15Port, 1},
{"PostgreSQL 15 (CPU=4 directory)", "15", env.TestPostgres15Port, 4},
{"PostgreSQL 16 (CPU=1 streamed)", "16", env.TestPostgres16Port, 1},
{"PostgreSQL 16 (CPU=4 directory)", "16", env.TestPostgres16Port, 4},
{"PostgreSQL 17 (CPU=1 streamed)", "17", env.TestPostgres17Port, 1},
{"PostgreSQL 17 (CPU=4 directory)", "17", env.TestPostgres17Port, 4},
{"PostgreSQL 18 (CPU=1 streamed)", "18", env.TestPostgres18Port, 1},
{"PostgreSQL 18 (CPU=4 directory)", "18", env.TestPostgres18Port, 4},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testBackupRestoreForVersion(t, tc.version, tc.port)
testBackupRestoreForVersion(t, tc.version, tc.port, tc.cpuCount)
})
}
}
@@ -361,7 +371,7 @@ func Test_BackupAndRestorePostgresql_WithReadOnlyUser_RestoreIsSuccessful(t *tes
}
}
func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cpuCount int) {
container, err := connectToPostgresContainer(pgVersion, port)
assert.NoError(t, err)
defer func() {
@@ -370,19 +380,25 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
database := createDatabaseViaAPI(
database := createDatabaseWithCpuCountViaAPI(
t, router, "Test Database", workspace.ID,
container.Host, container.Port,
container.Username, container.Password, container.Database,
cpuCount,
user.Token,
)
@@ -396,23 +412,28 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb"
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d_%s", pgVersion, cpuCount, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
assert.NoError(t, err)
defer newDB.Close()
createRestoreViaAPI(
createRestoreWithCpuCountViaAPI(
t, router, backup.ID,
container.Host, container.Port,
container.Username, container.Password, newDBName,
cpuCount,
user.Token,
)
@@ -422,12 +443,19 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -501,13 +529,17 @@ func testSchemaSelectionAllSchemasForVersion(t *testing.T, pgVersion string, por
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restored_all_schemas_" + pgVersion
newDBName := fmt.Sprintf("restored_all_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -625,14 +657,17 @@ func testBackupRestoreWithExcludeExtensionsForVersion(t *testing.T, pgVersion st
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
// Create new database for restore with extension pre-installed
newDBName := "restored_exclude_ext_" + pgVersion
newDBName := fmt.Sprintf("restored_exclude_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -756,14 +791,17 @@ func testBackupRestoreWithoutExcludeExtensionsForVersion(
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
// Create new database for restore WITHOUT pre-installed extension
newDBName := "restored_with_ext_" + pgVersion
newDBName := fmt.Sprintf("restored_with_ext_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -851,9 +889,14 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("ReadOnly Test Workspace", user, router)
@@ -887,13 +930,17 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_readonly"
newDBName := fmt.Sprintf("restoreddb_readonly_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -913,12 +960,19 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -996,13 +1050,17 @@ func testSchemaSelectionOnlySpecifiedSchemasForVersion(
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restored_specific_schemas_" + pgVersion
newDBName := fmt.Sprintf("restored_specific_schemas_%s_%s", pgVersion, uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -1074,9 +1132,14 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -1101,13 +1164,17 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_encrypted"
newDBName := fmt.Sprintf("restoreddb_encrypted_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
}()
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
@@ -1127,12 +1194,19 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -1258,6 +1332,27 @@ func createDatabaseViaAPI(
password string,
database string,
token string,
) *databases.Database {
return createDatabaseWithCpuCountViaAPI(
t, router, name, workspaceID,
host, port, username, password, database,
1,
token,
)
}
func createDatabaseWithCpuCountViaAPI(
t *testing.T,
router *gin.Engine,
name string,
workspaceID uuid.UUID,
host string,
port int,
username string,
password string,
database string,
cpuCount int,
token string,
) *databases.Database {
request := databases.Database{
Name: name,
@@ -1269,7 +1364,7 @@ func createDatabaseViaAPI(
Username: username,
Password: password,
Database: &database,
CpuCount: 1,
CpuCount: cpuCount,
},
}
@@ -1354,7 +1449,7 @@ func createRestoreViaAPI(
database string,
token string,
) {
createRestoreWithOptionsViaAPI(
createRestoreWithCpuCountViaAPI(
t,
router,
backupID,
@@ -1363,11 +1458,44 @@ func createRestoreViaAPI(
username,
password,
database,
false,
1,
token,
)
}
// createRestoreWithCpuCountViaAPI restores the given backup into the target
// PostgreSQL database through the public restores endpoint, using cpuCount
// parallel workers, and asserts the API responds with HTTP 200.
func createRestoreWithCpuCountViaAPI(
	t *testing.T,
	router *gin.Engine,
	backupID uuid.UUID,
	host string,
	port int,
	username string,
	password string,
	database string,
	cpuCount int,
	token string,
) {
	// Connection details of the database the backup is restored into.
	target := &pgtypes.PostgresqlDatabase{
		Host:     host,
		Port:     port,
		Username: username,
		Password: password,
		Database: &database,
		CpuCount: cpuCount,
	}

	url := fmt.Sprintf("/api/v1/restores/%s/restore", backupID.String())
	test_utils.MakePostRequest(
		t,
		router,
		url,
		"Bearer "+token,
		restores.RestoreBackupRequest{PostgresqlDatabase: target},
		http.StatusOK,
	)
}
func createRestoreWithOptionsViaAPI(
t *testing.T,
router *gin.Engine,
@@ -1540,14 +1668,14 @@ func createSupabaseRestoreViaAPI(
)
}
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB, tableName string) {
var originalData []TestDataItem
var restoredData []TestDataItem
err := originalDB.Select(&originalData, "SELECT * FROM test_data ORDER BY id")
err := originalDB.Select(&originalData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
err = restoredDB.Select(&restoredData, "SELECT * FROM test_data ORDER BY id")
err = restoredDB.Select(&restoredData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")

View File

@@ -1,10 +1,12 @@
package users_controllers
import (
"errors"
"net/http"
"databasus-backend/internal/config"
user_dto "databasus-backend/internal/features/users/dto"
users_errors "databasus-backend/internal/features/users/errors"
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
@@ -206,7 +208,7 @@ func (c *UserController) InviteUser(ctx *gin.Context) {
response, err := c.userService.InviteUser(&request, user)
if err != nil {
if err.Error() == "insufficient permissions to invite users" {
if errors.Is(err, users_errors.ErrInsufficientPermissionsToInviteUsers) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}

View File

@@ -0,0 +1,7 @@
package users_errors
import "errors"
var (
	// ErrInsufficientPermissionsToInviteUsers is the sentinel error returned
	// by the user service when the inviting user lacks invite rights;
	// controllers match it with errors.Is to map it to HTTP 403 Forbidden.
	// NOTE(review): workspaces_errors declares a separate sentinel with the
	// same message — errors.Is will not match across the two values; confirm
	// each controller compares against the package its service returns.
	ErrInsufficientPermissionsToInviteUsers = errors.New("insufficient permissions to invite users")
)

View File

@@ -20,6 +20,7 @@ import (
"databasus-backend/internal/features/encryption/secrets"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_errors "databasus-backend/internal/features/users/errors"
users_interfaces "databasus-backend/internal/features/users/interfaces"
users_models "databasus-backend/internal/features/users/models"
users_repositories "databasus-backend/internal/features/users/repositories"
@@ -340,7 +341,7 @@ func (s *UserService) InviteUser(
// Check if user has permission to invite
if !invitedBy.CanInviteUsers(settings) {
return nil, errors.New("insufficient permissions to invite users")
return nil, users_errors.ErrInsufficientPermissionsToInviteUsers
}
// Check if user already exists

View File

@@ -1,10 +1,12 @@
package workspaces_controllers
import (
"errors"
"net/http"
users_middleware "databasus-backend/internal/features/users/middleware"
workspaces_dto "databasus-backend/internal/features/workspaces/dto"
workspaces_errors "databasus-backend/internal/features/workspaces/errors"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
@@ -53,7 +55,7 @@ func (c *MembershipController) ListMembers(ctx *gin.Context) {
response, err := c.membershipService.GetMembers(workspaceID, user)
if err != nil {
if err.Error() == "insufficient permissions to view workspace members" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToViewMembers) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -105,8 +107,8 @@ func (c *MembershipController) AddMember(ctx *gin.Context) {
response, err := c.membershipService.AddMember(workspaceID, &request, user)
if err != nil {
if err.Error() == "insufficient permissions to manage members" ||
err.Error() == "only workspace owner can add/manage admins" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToManageMembers) ||
errors.Is(err, workspaces_errors.ErrOnlyOwnerCanAddManageAdmins) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -160,8 +162,8 @@ func (c *MembershipController) ChangeMemberRole(ctx *gin.Context) {
}
if err := c.membershipService.ChangeMemberRole(workspaceID, userID, &request, user); err != nil {
if err.Error() == "insufficient permissions to manage members" ||
err.Error() == "only workspace owner can add/manage admins" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToManageMembers) ||
errors.Is(err, workspaces_errors.ErrOnlyOwnerCanAddManageAdmins) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -206,8 +208,8 @@ func (c *MembershipController) RemoveMember(ctx *gin.Context) {
}
if err := c.membershipService.RemoveMember(workspaceID, userID, user); err != nil {
if err.Error() == "insufficient permissions to remove members" ||
err.Error() == "only workspace owner can remove admins" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToRemoveMembers) ||
errors.Is(err, workspaces_errors.ErrOnlyOwnerCanRemoveAdmins) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -253,7 +255,7 @@ func (c *MembershipController) TransferOwnership(ctx *gin.Context) {
}
if err := c.membershipService.TransferOwnership(workspaceID, &request, user); err != nil {
if err.Error() == "only workspace owner or admin can transfer ownership" {
if errors.Is(err, workspaces_errors.ErrOnlyOwnerOrAdminCanTransferOwnership) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}

View File

@@ -1,11 +1,13 @@
package workspaces_controllers
import (
"errors"
"net/http"
audit_logs "databasus-backend/internal/features/audit_logs"
users_middleware "databasus-backend/internal/features/users/middleware"
workspaces_dto "databasus-backend/internal/features/workspaces/dto"
workspaces_errors "databasus-backend/internal/features/workspaces/errors"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -56,7 +58,7 @@ func (c *WorkspaceController) CreateWorkspace(ctx *gin.Context) {
response, err := c.workspaceService.CreateWorkspace(&request, user)
if err != nil {
if err.Error() == "insufficient permissions to create workspaces" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToCreateWorkspaces) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -121,7 +123,7 @@ func (c *WorkspaceController) GetWorkspace(ctx *gin.Context) {
workspace, err := c.workspaceService.GetWorkspace(workspaceID, user)
if err != nil {
if err.Error() == "insufficient permissions to view workspace" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToViewWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -168,7 +170,7 @@ func (c *WorkspaceController) UpdateWorkspace(ctx *gin.Context) {
updatedWorkspace, err := c.workspaceService.UpdateWorkspace(workspaceID, &workspace, user)
if err != nil {
if err.Error() == "insufficient permissions to update workspace" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToUpdateWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -205,7 +207,7 @@ func (c *WorkspaceController) DeleteWorkspace(ctx *gin.Context) {
}
if err := c.workspaceService.DeleteWorkspace(workspaceID, user); err != nil {
if err.Error() == "only workspace owner or admin can delete workspace" {
if errors.Is(err, workspaces_errors.ErrOnlyOwnerOrAdminCanDeleteWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
@@ -254,7 +256,7 @@ func (c *WorkspaceController) GetWorkspaceAuditLogs(ctx *gin.Context) {
response, err := c.workspaceService.GetWorkspaceAuditLogs(workspaceID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view workspace audit logs" {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToViewWorkspaceAuditLogs) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}

View File

@@ -0,0 +1,56 @@
package workspaces_errors
import "errors"
var (
	// Sentinel errors for the workspaces feature. Services return these and
	// controllers match them with errors.Is to translate permission failures
	// into HTTP status codes (typically 403 Forbidden).

	// Workspace errors — permission checks on the workspace itself.
	ErrInsufficientPermissionsToCreateWorkspaces = errors.New(
		"insufficient permissions to create workspaces",
	)
	ErrInsufficientPermissionsToViewWorkspace = errors.New(
		"insufficient permissions to view workspace",
	)
	ErrInsufficientPermissionsToUpdateWorkspace = errors.New(
		"insufficient permissions to update workspace",
	)
	ErrInsufficientPermissionsToViewWorkspaceAuditLogs = errors.New(
		"insufficient permissions to view workspace audit logs",
	)
	ErrOnlyOwnerOrAdminCanDeleteWorkspace = errors.New(
		"only workspace owner or admin can delete workspace",
	)

	// Membership errors — managing members, roles, and ownership transfer.
	ErrInsufficientPermissionsToViewMembers = errors.New(
		"insufficient permissions to view workspace members",
	)
	ErrInsufficientPermissionsToManageMembers = errors.New(
		"insufficient permissions to manage members",
	)
	ErrInsufficientPermissionsToRemoveMembers = errors.New(
		"insufficient permissions to remove members",
	)
	// NOTE(review): users_errors declares a sentinel with this same message;
	// errors.Is will not match across the two distinct values — confirm each
	// controller checks the package its service actually returns.
	ErrInsufficientPermissionsToInviteUsers = errors.New(
		"insufficient permissions to invite users",
	)
	ErrOnlyOwnerCanAddManageAdmins = errors.New(
		"only workspace owner can add/manage admins",
	)
	ErrOnlyOwnerCanRemoveAdmins = errors.New("only workspace owner can remove admins")
	ErrOnlyOwnerOrAdminCanTransferOwnership = errors.New(
		"only workspace owner or admin can transfer ownership",
	)
	ErrUserAlreadyMember = errors.New(
		"user is already a member of this workspace",
	)
	ErrCannotChangeOwnRole      = errors.New("cannot change your own role")
	ErrUserNotMemberOfWorkspace = errors.New("user is not a member of this workspace")
	ErrCannotChangeOwnerRole    = errors.New("cannot change owner role")
	ErrUserNotFound             = errors.New("user not found")
	ErrCannotRemoveWorkspaceOwner = errors.New(
		"cannot remove workspace owner, transfer ownership first",
	)
	ErrNewOwnerNotFound        = errors.New("new owner not found")
	ErrNewOwnerMustBeMember    = errors.New("new owner must be a workspace member")
	ErrNoCurrentWorkspaceOwner = errors.New("no current workspace owner found")
)

View File

@@ -1,7 +1,6 @@
package workspaces_services
import (
"errors"
"fmt"
audit_logs "databasus-backend/internal/features/audit_logs"
@@ -10,6 +9,7 @@ import (
users_models "databasus-backend/internal/features/users/models"
users_services "databasus-backend/internal/features/users/services"
workspaces_dto "databasus-backend/internal/features/workspaces/dto"
workspaces_errors "databasus-backend/internal/features/workspaces/errors"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_repositories "databasus-backend/internal/features/workspaces/repositories"
@@ -34,7 +34,7 @@ func (s *MembershipService) GetMembers(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view workspace members")
return nil, workspaces_errors.ErrInsufficientPermissionsToViewMembers
}
members, err := s.membershipRepository.GetWorkspaceMembers(workspaceID)
@@ -74,7 +74,7 @@ func (s *MembershipService) AddMember(
}
if !addedBy.CanInviteUsers(settings) {
return nil, errors.New("insufficient permissions to invite users")
return nil, workspaces_errors.ErrInsufficientPermissionsToInviteUsers
}
inviteRequest := &users_dto.InviteUserRequestDTO{
@@ -118,7 +118,7 @@ func (s *MembershipService) AddMember(
workspaceID,
)
if existingMembership != nil {
return nil, errors.New("user is already a member of this workspace")
return nil, workspaces_errors.ErrUserAlreadyMember
}
membership := &workspaces_models.WorkspaceMembership{
@@ -153,7 +153,7 @@ func (s *MembershipService) ChangeMemberRole(
}
if memberUserID == changedBy.ID {
return errors.New("cannot change your own role")
return workspaces_errors.ErrCannotChangeOwnRole
}
existingMembership, err := s.membershipRepository.GetMembershipByUserAndWorkspace(
@@ -161,16 +161,16 @@ func (s *MembershipService) ChangeMemberRole(
workspaceID,
)
if err != nil {
return errors.New("user is not a member of this workspace")
return workspaces_errors.ErrUserNotMemberOfWorkspace
}
if existingMembership.Role == users_enums.WorkspaceRoleOwner {
return errors.New("cannot change owner role")
return workspaces_errors.ErrCannotChangeOwnerRole
}
targetUser, err := s.userService.GetUserByID(memberUserID)
if err != nil {
return errors.New("user not found")
return workspaces_errors.ErrUserNotFound
}
if err := s.membershipRepository.UpdateMemberRole(memberUserID, workspaceID, request.Role); err != nil {
@@ -202,7 +202,7 @@ func (s *MembershipService) RemoveMember(
}
if !canManage {
return errors.New("insufficient permissions to remove members")
return workspaces_errors.ErrInsufficientPermissionsToRemoveMembers
}
existingMembership, err := s.membershipRepository.GetMembershipByUserAndWorkspace(
@@ -210,11 +210,11 @@ func (s *MembershipService) RemoveMember(
workspaceID,
)
if err != nil {
return errors.New("user is not a member of this workspace")
return workspaces_errors.ErrUserNotMemberOfWorkspace
}
if existingMembership.Role == users_enums.WorkspaceRoleOwner {
return errors.New("cannot remove workspace owner, transfer ownership first")
return workspaces_errors.ErrCannotRemoveWorkspaceOwner
}
if existingMembership.Role == users_enums.WorkspaceRoleAdmin {
@@ -223,13 +223,13 @@ func (s *MembershipService) RemoveMember(
return err
}
if !canManageAdmins {
return errors.New("only workspace owner can remove admins")
return workspaces_errors.ErrOnlyOwnerCanRemoveAdmins
}
}
targetUser, err := s.userService.GetUserByID(memberUserID)
if err != nil {
return errors.New("user not found")
return workspaces_errors.ErrUserNotFound
}
if err := s.membershipRepository.RemoveMember(memberUserID, workspaceID); err != nil {
@@ -257,21 +257,21 @@ func (s *MembershipService) TransferOwnership(
if user.Role != users_enums.UserRoleAdmin &&
(currentRole == nil || *currentRole != users_enums.WorkspaceRoleOwner) {
return errors.New("only workspace owner or admin can transfer ownership")
return workspaces_errors.ErrOnlyOwnerOrAdminCanTransferOwnership
}
newOwner, err := s.userService.GetUserByEmail(request.NewOwnerEmail)
if err != nil {
return errors.New("new owner not found")
return workspaces_errors.ErrNewOwnerNotFound
}
if newOwner == nil {
return errors.New("new owner not found")
return workspaces_errors.ErrNewOwnerNotFound
}
_, err = s.membershipRepository.GetMembershipByUserAndWorkspace(newOwner.ID, workspaceID)
if err != nil {
return errors.New("new owner must be a workspace member")
return workspaces_errors.ErrNewOwnerMustBeMember
}
currentOwner, err := s.membershipRepository.GetWorkspaceOwner(workspaceID)
@@ -280,7 +280,7 @@ func (s *MembershipService) TransferOwnership(
}
if currentOwner == nil {
return errors.New("no current workspace owner found")
return workspaces_errors.ErrNoCurrentWorkspaceOwner
}
if err := s.membershipRepository.UpdateMemberRole(newOwner.ID, workspaceID, users_enums.WorkspaceRoleOwner); err != nil {
@@ -311,7 +311,7 @@ func (s *MembershipService) validateCanManageMembership(
return err
}
if !canManageAdmins {
return errors.New("only workspace owner can add/manage admins")
return workspaces_errors.ErrOnlyOwnerCanAddManageAdmins
}
return nil
}
@@ -322,7 +322,7 @@ func (s *MembershipService) validateCanManageMembership(
}
if !canManageMembership {
return errors.New("insufficient permissions to manage members")
return workspaces_errors.ErrInsufficientPermissionsToManageMembers
}
return nil

View File

@@ -1,7 +1,6 @@
package workspaces_services
import (
"errors"
"fmt"
"time"
@@ -10,6 +9,7 @@ import (
users_models "databasus-backend/internal/features/users/models"
users_services "databasus-backend/internal/features/users/services"
workspaces_dto "databasus-backend/internal/features/workspaces/dto"
workspaces_errors "databasus-backend/internal/features/workspaces/errors"
workspaces_interfaces "databasus-backend/internal/features/workspaces/interfaces"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_repositories "databasus-backend/internal/features/workspaces/repositories"
@@ -43,7 +43,7 @@ func (s *WorkspaceService) CreateWorkspace(
}
if !creator.CanCreateWorkspaces(settings) {
return nil, errors.New("insufficient permissions to create workspaces")
return nil, workspaces_errors.ErrInsufficientPermissionsToCreateWorkspaces
}
workspace := &workspaces_models.Workspace{
@@ -91,7 +91,7 @@ func (s *WorkspaceService) GetWorkspace(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view workspace")
return nil, workspaces_errors.ErrInsufficientPermissionsToViewWorkspace
}
return s.workspaceRepository.GetWorkspaceByID(workspaceID)
@@ -121,7 +121,7 @@ func (s *WorkspaceService) UpdateWorkspace(
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to update workspace")
return nil, workspaces_errors.ErrInsufficientPermissionsToUpdateWorkspace
}
existingWorkspace, err := s.workspaceRepository.GetWorkspaceByID(workspaceID)
@@ -155,7 +155,7 @@ func (s *WorkspaceService) DeleteWorkspace(workspaceID uuid.UUID, user *users_mo
}
if userWorkspaceRole == nil || *userWorkspaceRole != users_enums.WorkspaceRoleOwner {
return errors.New("only workspace owner or admin can delete workspace")
return workspaces_errors.ErrOnlyOwnerOrAdminCanDeleteWorkspace
}
}
@@ -299,7 +299,7 @@ func (s *WorkspaceService) GetWorkspaceAuditLogs(
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view workspace audit logs")
return nil, workspaces_errors.ErrInsufficientPermissionsToViewWorkspaceAuditLogs
}
return s.auditLogService.GetWorkspaceAuditLogs(workspaceID, request)

View File

@@ -0,0 +1,59 @@
package encryption
import (
"crypto/rand"
"math/big"
)
// GenerateComplexPassword returns a 24-character random password that meets
// common cloud-provider complexity requirements: it always contains at least
// one lowercase letter, one uppercase letter, one digit, and one special
// character. The result is shuffled so the guaranteed characters do not sit
// at predictable positions.
func GenerateComplexPassword() string {
	const (
		lowercase = "abcdefghijklmnopqrstuvwxyz"
		uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		digits    = "0123456789"
		special   = "!@#$%^&*()-_=+"
		all       = lowercase + uppercase + digits + special
	)
	const passwordLength = 24

	buf := make([]byte, 0, passwordLength)

	// One guaranteed character from each required class.
	for _, class := range []string{lowercase, uppercase, digits, special} {
		buf = append(buf, randomChar(class))
	}

	// The remaining characters may come from any class.
	for len(buf) < passwordLength {
		buf = append(buf, randomChar(all))
	}

	// Hide the class-ordered prefix.
	shuffleBytes(buf)

	return string(buf)
}
// randomChar picks one byte of charset uniformly at random using crypto/rand.
// If the system randomness source fails (practically impossible), it falls
// back to the first byte of charset rather than propagating an error.
func randomChar(charset string) byte {
	upper := big.NewInt(int64(len(charset)))
	idx, err := rand.Int(rand.Reader, upper)
	if err != nil {
		return charset[0]
	}
	return charset[idx.Int64()]
}
// shuffleBytes performs an in-place Fisher-Yates shuffle of b driven by
// crypto/rand. A position whose random draw fails is simply left where it
// is; the slice always keeps the same multiset of bytes.
func shuffleBytes(b []byte) {
	for i := len(b) - 1; i > 0; i-- {
		r, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
		if err != nil {
			continue
		}
		swapWith := r.Int64()
		b[i], b[swapWith] = b[swapWith], b[i]
	}
}

View File

@@ -176,6 +176,13 @@ func getPostgresqlBasePath(
postgresesInstallDir string,
) string {
if envMode == env_utils.EnvModeDevelopment {
// On Windows, PostgreSQL 12 and 13 have issues with piping over restore
if runtime.GOOS == "windows" {
if version == PostgresqlVersion12 || version == PostgresqlVersion13 {
version = PostgresqlVersion14
}
}
return filepath.Join(
postgresesInstallDir,
fmt.Sprintf("postgresql-%s", string(version)),

View File

@@ -0,0 +1,9 @@
-- Migration: add a discriminator column to backups so each row can be
-- tagged with a backup kind; existing rows are backfilled via the
-- 'DEFAULT' column default so NOT NULL holds immediately.
-- +goose Up
-- +goose StatementBegin
ALTER TABLE backups ADD COLUMN type TEXT NOT NULL DEFAULT 'DEFAULT';
-- +goose StatementEnd
-- Rollback: drop the column again.
-- +goose Down
-- +goose StatementBegin
ALTER TABLE backups DROP COLUMN type;
-- +goose StatementEnd

View File

@@ -0,0 +1,9 @@
-- Migration: remove the backups.type discriminator column (reverts the
-- earlier migration that introduced it).
-- +goose Up
-- +goose StatementBegin
ALTER TABLE backups DROP COLUMN type;
-- +goose StatementEnd
-- Rollback: restore the column with its original default. NOTE: any
-- per-row values that existed before the drop are not recoverable; all
-- rows come back as 'DEFAULT'.
-- +goose Down
-- +goose StatementBegin
ALTER TABLE backups ADD COLUMN type TEXT NOT NULL DEFAULT 'DEFAULT';
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- Migration: add a flag to MariaDB database configs controlling whether
-- scheduled events are excluded from dumps. Defaults to FALSE so
-- existing configs keep their current dump behavior.
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- Rollback: drop the flag.
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
DROP COLUMN is_exclude_events;
-- +goose StatementEnd

View File

@@ -0,0 +1,23 @@
-- Migration: store the detected user privileges for MySQL and MariaDB
-- database configs (used to decide which items can be backed up), and
-- retire the narrower is_exclude_events flag which this supersedes.
-- Empty-string default marks "privileges not detected yet".
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mysql_databases
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
ALTER TABLE mariadb_databases
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
ALTER TABLE mariadb_databases
DROP COLUMN is_exclude_events;
-- +goose StatementEnd
-- Rollback: restore is_exclude_events (values prior to this migration
-- are not recoverable; all rows come back FALSE) and drop privileges.
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE mariadb_databases
DROP COLUMN privileges;
ALTER TABLE mysql_databases
DROP COLUMN privileges;
-- +goose StatementEnd

View File

@@ -0,0 +1,43 @@
-- Migration: recreate the workspace foreign keys on notifiers and
-- storages with ON DELETE CASCADE, so deleting a workspace removes its
-- dependent rows instead of failing on the FK. The constraint names are
-- preserved, so each FK must be dropped before being re-added.
-- +goose Up
-- +goose StatementBegin
ALTER TABLE notifiers
DROP CONSTRAINT fk_notifiers_workspace_id;
ALTER TABLE notifiers
ADD CONSTRAINT fk_notifiers_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id)
ON DELETE CASCADE;
ALTER TABLE storages
DROP CONSTRAINT fk_storages_workspace_id;
ALTER TABLE storages
ADD CONSTRAINT fk_storages_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id)
ON DELETE CASCADE;
-- +goose StatementEnd
-- Rollback: recreate both FKs without the cascade action (reverting to
-- the engine's default delete behavior).
-- +goose Down
-- +goose StatementBegin
ALTER TABLE notifiers
DROP CONSTRAINT fk_notifiers_workspace_id;
ALTER TABLE notifiers
ADD CONSTRAINT fk_notifiers_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id);
ALTER TABLE storages
DROP CONSTRAINT fk_storages_workspace_id;
ALTER TABLE storages
ADD CONSTRAINT fk_storages_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id);
-- +goose StatementEnd

View File

@@ -1 +0,0 @@
This is test data for storage testing

Some files were not shown because too many files have changed in this diff Show More