Compare commits

...

25 Commits

Author SHA1 Message Date
Rostislav Dugin
cba40afd00 Merge pull request #228 from databasus/develop
FIX (backend): Fix formatting
2026-01-08 17:11:43 +03:00
Rostislav Dugin
7aea012aeb FIX (backend): Fix formatting 2026-01-08 17:10:47 +03:00
Rostislav Dugin
6d5534deaa Merge pull request #227 from databasus/develop
Develop
2026-01-08 16:55:12 +03:00
Rostislav Dugin
c04bd54683 FIX (download): Add streamable download of backups 2026-01-08 15:55:52 +03:00
Rostislav Dugin
1c3f16b372 FIX (google drive): Fix UI after new local redirect PR 2026-01-08 12:22:47 +03:00
Rostislav Dugin
ed08da56a6 FIX (cicd): Get rid of CITATION auto generate 2026-01-08 11:35:55 +03:00
Rostislav Dugin
c53e84b48d FIX (devex): Fix Linux tools installation script 2026-01-08 11:34:35 +03:00
Rostislav Dugin
dbfeb9e27f merge develop 2026-01-08 11:33:34 +03:00
Rostislav Dugin
02e86ffb3b FIX (devex): Fix Linux tools installation script 2026-01-08 11:10:56 +03:00
github-actions[bot]
207382116c Update CITATION.cff to v2.21.0 2026-01-05 18:38:28 +00:00
Rostislav Dugin
a91ee50e31 Merge pull request #221 from databasus/develop
Develop
2026-01-05 21:08:50 +03:00
Rostislav Dugin
7e5562b115 FEATURE (mysql): Add automatic detection of allowed privileges to backup proper DB items 2026-01-05 21:07:53 +03:00
Rostislav Dugin
3ef51c4d68 FEATURE (databases): Imrove check for required permissions to backup, check for read-only user and extend DBs models tests 2026-01-05 21:07:53 +03:00
github-actions[bot]
e47e513460 Update CITATION.cff to v2.20.3 2026-01-04 21:28:24 +00:00
Rostislav Dugin
226a6c06e6 Merge pull request #216 from databasus/develop
FIX (readonly user): Improve complexity of readonly user passwords to…
2026-01-05 00:07:37 +03:00
Rostislav Dugin
615fd9d574 FIX (readonly user): Improve complexity of readonly user passwords to pass Google Cloud requirements 2026-01-05 00:06:24 +03:00
github-actions[bot]
e9fcf20cdf Update CITATION.cff to v2.20.2 2026-01-04 20:14:38 +00:00
Rostislav Dugin
7649f4acfd Merge pull request #214 from databasus/develop
FIX (databases): Add timeout for deletion in case of storage stuck
2026-01-04 22:54:13 +03:00
Rostislav Dugin
7e4c3bcc19 FIX (databases): Add timeout for deletion in case of storage stuck 2026-01-04 22:51:11 +03:00
Rostislav Dugin
f2aecc0427 Merge pull request #212 from databasus/develop
FIX (mariadb): Add events exclusion for MariaDB
2026-01-04 22:16:24 +03:00
Rostislav Dugin
3ce7da319f FIX (mariadb): Add events exclusion for MariaDB 2026-01-04 22:15:31 +03:00
github-actions[bot]
096098f660 Update CITATION.cff to v2.20.1 2026-01-04 18:13:01 +00:00
Rostislav Dugin
c3ba4a7c5a Merge pull request #209 from databasus/develop
FIX (backups): Escape password over connection check to allow whitesp…
2026-01-04 20:50:08 +03:00
Rostislav Dugin
52c0f53608 FIX (backups): Escape password over connection check to allow whitespaces 2026-01-04 20:49:22 +03:00
github-actions[bot]
a5095acad4 Update CITATION.cff to v2.20.0 2026-01-04 14:54:48 +00:00
78 changed files with 3515 additions and 545 deletions

View File

@@ -672,17 +672,6 @@ jobs:
echo EOF
} >> $GITHUB_OUTPUT
- name: Update CITATION.cff version
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add CITATION.cff
git commit -m "Update CITATION.cff to v${VERSION}" || true
git push || true
- name: Create GitHub Release
uses: actions/create-release@v1
env:

View File

@@ -6,14 +6,14 @@ repos:
hooks:
- id: frontend-format
name: Frontend Format (Prettier)
entry: powershell -Command "cd frontend; npm run format"
entry: bash -c "cd frontend && npm run format"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css|md)$
pass_filenames: false
- id: frontend-lint
name: Frontend Lint (ESLint)
entry: powershell -Command "cd frontend; npm run lint"
entry: bash -c "cd frontend && npm run lint"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx)$
pass_filenames: false
@@ -23,7 +23,7 @@ repos:
hooks:
- id: backend-format-and-lint
name: Backend Format & Lint (golangci-lint)
entry: powershell -Command "cd backend; golangci-lint fmt; golangci-lint run"
entry: bash -c "cd backend && golangci-lint fmt ./internal/... ./cmd/... && golangci-lint run ./internal/... ./cmd/..."
language: system
files: ^backend/.*\.go$
pass_filenames: false
pass_filenames: false

View File

@@ -32,5 +32,5 @@ keywords:
- mongodb
- mariadb
license: Apache-2.0
version: 2.19.2
date-released: "2026-01-02"
version: 2.21.0
date-released: "2026-01-05"

View File

@@ -2,7 +2,7 @@ run:
go run cmd/main.go
test:
go test -p=1 -count=1 -failfast -timeout 10m .\internal\...
go test -p=1 -count=1 -failfast -timeout 10m ./internal/...
lint:
golangci-lint fmt && golangci-lint run

View File

@@ -183,6 +183,7 @@ func setUpRoutes(r *gin.Engine) {
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
backups.GetBackupController().RegisterPublicRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
@@ -243,6 +244,10 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "audit log cleanup background service", func() {
audit_logs.GetAuditLogBackgroundService().Run()
})
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups.GetDownloadTokenBackgroundService().Run()
})
}
func runWithPanicLogging(log *slog.Logger, serviceName string, fn func()) {

View File

@@ -18,11 +18,17 @@ type BackupController struct {
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/backups", c.GetBackups)
router.POST("/backups", c.MakeBackup)
router.GET("/backups/:id/file", c.GetFile)
router.POST("/backups/:id/download-token", c.GenerateDownloadToken)
router.DELETE("/backups/:id", c.DeleteBackup)
router.POST("/backups/:id/cancel", c.CancelBackup)
}
// RegisterPublicRoutes registers routes that don't require Bearer authentication
// (they have their own authentication mechanisms like download tokens)
func (c *BackupController) RegisterPublicRoutes(router *gin.RouterGroup) {
router.GET("/backups/:id/file", c.GetFile)
}
// GetBackups
// @Summary Get backups for a database
// @Description Get paginated backups for the specified database
@@ -159,17 +165,16 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
ctx.Status(http.StatusNoContent)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup
// GenerateDownloadToken
// @Summary Generate short-lived download token
// @Description Generate a token for downloading a backup file (valid for 5 minutes)
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 200 {file} file
// @Success 200 {object} GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
@@ -182,7 +187,56 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
fileReader, backup, database, err := c.backupService.GetBackupFile(user, id)
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
if token == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "download token is required"})
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
downloadToken, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
fileReader, backup, database, err := c.backupService.GetBackupFileWithoutAuth(
downloadToken.BackupID,
)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -195,6 +249,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
}
ctx.Header("Content-Type", "application/octet-stream")
ctx.Header(
"Content-Disposition",
@@ -203,9 +263,12 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
_, err = io.Copy(ctx.Writer, fileReader)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
fmt.Printf("Error streaming file: %v\n", err)
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
type MakeBackupRequest struct {

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,7 +16,9 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -87,7 +90,13 @@ func Test_GetBackups_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -179,7 +188,13 @@ func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -309,7 +324,13 @@ func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -378,7 +399,7 @@ func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
assert.True(t, found, "Audit log for backup deletion not found")
}
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
func Test_GenerateDownloadToken_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
@@ -387,28 +408,28 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
expectedStatusCode int
}{
{
name: "workspace viewer can download backup",
name: "workspace viewer can generate token",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can download backup",
name: "workspace member can generate token",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot download backup",
name: "non-member cannot generate token",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can download backup",
name: "global admin can generate token",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
@@ -433,7 +454,13 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
} else {
@@ -441,21 +468,244 @@ func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+testUserToken,
nil,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
if tt.expectSuccess {
var response GenerateDownloadTokenResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.NotEmpty(t, response.Token)
assert.NotEmpty(t, response.Filename)
assert.Equal(t, backup.ID, response.BackupID)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_DownloadBackup_WithValidToken_Success(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
// Verify response
contentDisposition := testResp.Headers.Get("Content-Disposition")
assert.Contains(t, contentDisposition, "attachment")
assert.Contains(t, contentDisposition, tokenResponse.Filename)
}
func Test_DownloadBackup_WithoutToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download without token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "download token is required")
}
func Test_DownloadBackup_WithInvalidToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Try to download with invalid token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), "invalid-token-xyz"),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_WithExpiredToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Get user for token generation
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
// Create an expired token directly in the database
expiredToken := createExpiredDownloadToken(backup.ID, user.ID)
// Try to download with expired token
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), expiredToken),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
// Verify audit log was NOT created for failed download
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup file downloaded") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.False(t, found, "Audit log should NOT be created for failed download with expired token")
}
func Test_DownloadBackup_TokenUsedOnce_CannotReuseToken(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token (first time - should succeed)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
// Try to download again with same token (should fail)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_WithDifferentBackupToken_Unauthorized(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database1 := createTestDatabase("Database 1", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config1, err := configService.GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
config1.IsBackupsEnabled = true
config1.StorageID = &storage.ID
config1.Storage = storage
_, err = configService.SaveBackupConfig(config1)
assert.NoError(t, err)
backup1 := createTestBackup(database1, owner)
database2 := createTestDatabase("Database 2", workspace.ID, owner.Token, router)
config2, err := configService.GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
config2.IsBackupsEnabled = true
config2.StorageID = &storage.ID
config2.Storage = storage
_, err = configService.SaveBackupConfig(config2)
assert.NoError(t, err)
backup2 := createTestBackup(database2, owner)
// Generate token for backup1
var tokenResponse GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup1.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Try to use backup1's token to download backup2 (should fail)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup2.ID.String(), tokenResponse.Token),
"",
http.StatusUnauthorized,
)
assert.Contains(t, string(testResp.Body), "invalid or expired download token")
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -463,11 +713,24 @@ func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), tokenResponse.Token),
"",
http.StatusOK,
)
@@ -542,11 +805,28 @@ func Test_DownloadBackup_ProperFilenameForPostgreSQL(t *testing.T) {
backup := createTestBackup(database, owner)
// Generate download token
var tokenResponse GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&tokenResponse,
)
// Download with token
resp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
tokenResponse.Token,
),
"",
http.StatusOK,
)
@@ -679,7 +959,13 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -687,9 +973,9 @@ func createTestDatabase(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -809,9 +1095,38 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(
context.Background(),
encryption.GetFieldEncryptor(),
logger,
backup.ID,
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}
func createExpiredDownloadToken(backupID, userID uuid.UUID) string {
tokenService := GetBackupService().downloadTokenService
token, err := tokenService.Generate(backupID, userID)
if err != nil {
panic(fmt.Sprintf("Failed to generate download token: %v", err))
}
// Manually update the token to be expired
repo := &download_token.DownloadTokenRepository{}
downloadToken, err := repo.FindByToken(token)
if err != nil || downloadToken == nil {
panic(fmt.Sprintf("Failed to find generated token: %v", err))
}
// Set expiration to 10 minutes ago
downloadToken.ExpiresAt = time.Now().UTC().Add(-10 * time.Minute)
if err := repo.Update(downloadToken); err != nil {
panic(fmt.Sprintf("Failed to update token expiration: %v", err))
}
return token
}

View File

@@ -4,6 +4,7 @@ import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -34,6 +35,7 @@ var backupService = &BackupService{
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
download_token.GetDownloadTokenService(),
}
var backupBackgroundService = &BackupBackgroundService{
@@ -69,3 +71,7 @@ func GetBackupController() *BackupController {
func GetBackupBackgroundService() *BackupBackgroundService {
return backupBackgroundService
}
func GetDownloadTokenBackgroundService() *download_token.DownloadTokenBackgroundService {
return download_token.GetDownloadTokenBackgroundService()
}

View File

@@ -0,0 +1,32 @@
package download_token
import (
"databasus-backend/internal/config"
"log/slog"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
}
func (s *DownloadTokenBackgroundService) Run() {
s.logger.Info("Starting download token cleanup background service")
if config.IsShouldShutdown() {
return
}
for {
if config.IsShouldShutdown() {
return
}
if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
s.logger.Error("Failed to clean expired download tokens", "error", err)
}
time.Sleep(1 * time.Minute)
}
}

View File

@@ -0,0 +1,25 @@
package download_token
import (
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService,
logger.GetLogger(),
}
func GetDownloadTokenService() *DownloadTokenService {
return downloadTokenService
}
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}

View File

@@ -0,0 +1,21 @@
package download_token
import (
"time"
"github.com/google/uuid"
)
type DownloadToken struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey"`
Token string `json:"token" gorm:"column:token;uniqueIndex;not null"`
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;not null"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null"`
ExpiresAt time.Time `json:"expiresAt" gorm:"column:expires_at;not null"`
Used bool `json:"used" gorm:"column:used;not null;default:false"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at;not null"`
}
func (DownloadToken) TableName() string {
return "download_tokens"
}

View File

@@ -0,0 +1,60 @@
package download_token
import (
"crypto/rand"
"databasus-backend/internal/storage"
"encoding/base64"
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
type DownloadTokenRepository struct{}
func (r *DownloadTokenRepository) Create(token *DownloadToken) error {
if token.ID == uuid.Nil {
token.ID = uuid.New()
}
if token.CreatedAt.IsZero() {
token.CreatedAt = time.Now().UTC()
}
return storage.GetDb().Create(token).Error
}
func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, error) {
var downloadToken DownloadToken
err := storage.GetDb().
Where("token = ?", token).
First(&downloadToken).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &downloadToken, nil
}
func (r *DownloadTokenRepository) Update(token *DownloadToken) error {
return storage.GetDb().Save(token).Error
}
func (r *DownloadTokenRepository) DeleteExpired(before time.Time) error {
return storage.GetDb().
Where("expires_at < ?", before).
Delete(&DownloadToken{}).Error
}
func GenerateSecureToken() string {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
panic("failed to generate secure random token: " + err.Error())
}
return base64.URLEncoding.EncodeToString(b)
}

View File

@@ -0,0 +1,69 @@
package download_token
import (
"errors"
"log/slog"
"time"
"github.com/google/uuid"
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
token := GenerateSecureToken()
downloadToken := &DownloadToken{
Token: token,
BackupID: backupID,
UserID: userID,
ExpiresAt: time.Now().UTC().Add(5 * time.Minute),
Used: false,
}
if err := s.repository.Create(downloadToken); err != nil {
return "", err
}
s.logger.Info("Generated download token", "backupId", backupID, "userId", userID)
return token, nil
}
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, err
}
if dt == nil {
return nil, errors.New("invalid token")
}
if dt.Used {
return nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, errors.New("token expired")
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
return dt, nil
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {
return err
}
s.logger.Debug("Cleaned expired download tokens")
return nil
}

View File

@@ -3,6 +3,8 @@ package backups
import (
"databasus-backend/internal/features/backups/backups/encryption"
"io"
"github.com/google/uuid"
)
type GetBackupsRequest struct {
@@ -18,6 +20,12 @@ type GetBackupsResponse struct {
Offset int `json:"offset"`
}
type GenerateDownloadTokenResponse struct {
Token string `json:"token"`
Filename string `json:"filename"`
BackupID uuid.UUID `json:"backupId"`
}
type decryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser

View File

@@ -12,6 +12,7 @@ import (
"time"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/download_token"
"databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -44,6 +45,7 @@ type BackupService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
downloadTokenService *download_token.DownloadTokenService
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
@@ -683,3 +685,113 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
fileReader,
}, nil
}
func (s *BackupService) GenerateDownloadToken(
user *users_models.User,
backupID uuid.UUID,
) (*GenerateDownloadTokenResponse, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to download backup for this database")
}
token, err := s.downloadTokenService.Generate(backupID, user.ID)
if err != nil {
return nil, err
}
filename := s.generateBackupFilename(backup, database)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Download token generated for backup of database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return &GenerateDownloadTokenResponse{
Token: token,
Filename: filename,
BackupID: backupID,
}, nil
}
func (s *BackupService) ValidateDownloadToken(token string) (*download_token.DownloadToken, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
func (s *BackupService) GetBackupFileWithoutAuth(
backupID uuid.UUID,
) (io.ReadCloser, *Backup, *databases.Database, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, nil, nil, err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, nil, nil, err
}
reader, err := s.getBackupReader(backupID)
if err != nil {
return nil, nil, nil, err
}
return reader, backup, database, nil
}
func (s *BackupService) WriteAuditLogForDownload(
userID uuid.UUID,
backup *Backup,
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backup.ID.String(),
),
&userID,
database.WorkspaceID,
)
}
func (s *BackupService) generateBackupFilename(
backup *Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
safeName := sanitizeFilename(database.Name)
extension := s.getBackupExtension(database.Type)
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
}
func (s *BackupService) getBackupExtension(dbType databases.DatabaseType) string {
switch dbType {
case databases.DatabaseTypeMysql, databases.DatabaseTypeMariadb:
return ".sql.zst"
case databases.DatabaseTypePostgres:
return ".dump"
case databases.DatabaseTypeMongodb:
return ".archive"
default:
return ".backup"
}
}

View File

@@ -65,6 +65,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
// Set up expectations
@@ -113,6 +114,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
backupService.MakeBackup(database.ID, true)
@@ -138,6 +140,7 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
nil,
}
// capture arguments

View File

@@ -14,13 +14,19 @@ import (
)
func CreateTestRouter() *gin.Engine {
return workspaces_testing.CreateTestRouter(
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
GetBackupController(),
)
// Register public routes (no auth required - token-based)
v1 := router.Group("/api/v1")
GetBackupController().RegisterPublicRoutes(v1)
return router
}
// WaitForBackupCompletion waits for a new backup to be created and completed (or failed)

View File

@@ -107,12 +107,17 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--user=" + mdb.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--quick",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, "--compress")
if mdb.IsHttps {

View File

@@ -105,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--user=" + my.Username,
"--single-transaction",
"--routines",
"--triggers",
"--events",
"--set-gtid-purged=OFF",
"--quick",
"--verbose",
}
if my.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if my.HasPrivilege("EVENT") {
args = append(args, "--events")
}
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
if my.IsHttps {

View File

@@ -135,7 +135,14 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, db.Postgresql.CpuCount, pgBin); err != nil {
if err := uc.setupPgEnvironment(
cmd,
pgpassFile,
db.Postgresql.IsHttps,
password,
db.Postgresql.CpuCount,
pgBin,
); err != nil {
return nil, err
}

View File

@@ -2,13 +2,16 @@ package backups_config
import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
@@ -94,7 +97,13 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -241,7 +250,13 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1434,7 +1449,13 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
@@ -1442,9 +1463,9 @@ func createTestDatabaseViaAPI(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -1459,7 +1480,9 @@ func createTestDatabaseViaAPI(
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database

View File

@@ -392,13 +392,13 @@ func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
return
}
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
isReadOnly, privileges, err := c.databaseService.IsUserReadOnly(user, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly, Privileges: privileges})
}
// CreateReadOnlyUser

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"testing"
@@ -11,6 +12,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
@@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine {
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -84,24 +151,21 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
var response Database
@@ -132,20 +196,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: getTestPostgresConfig(),
}
testResp := test_utils.MakePostRequest(
@@ -214,7 +269,13 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -316,7 +377,13 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -381,7 +448,13 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
testUser = admin.Token
} else if tt.userRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.userRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.userRole,
owner.Token,
router,
)
testUser = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -605,7 +678,13 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -737,7 +816,13 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *Database {
testDbName := "test_db"
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
@@ -745,9 +830,9 @@ func createTestDatabaseViaAPI(
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
@@ -780,21 +865,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
pgConfig := getTestPostgresConfig()
plainPassword := "testpassword"
pgConfig.Password = plainPassword
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
var createdDatabase Database
@@ -854,38 +932,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "PostgreSQL Database",
databaseType: DatabaseTypePostgres,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
pgConfig := getTestPostgresConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "original-password-secret",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
pgConfig := getTestPostgresConfig()
pgConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion17,
Host: "updated-host",
Port: 5433,
Username: "updated_user",
Password: "",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: pgConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -895,7 +958,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)
@@ -905,36 +968,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MariaDB Database",
databaseType: DatabaseTypeMariadb,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
mariaConfig := getTestMariadbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "original-password-secret",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
mariaConfig := getTestMariadbConfig()
mariaConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MariaDB Database",
Type: DatabaseTypeMariadb,
Mariadb: &mariadb.MariadbDatabase{
Version: tools.MariadbVersion114,
Host: "updated-host",
Port: 3307,
Username: "updated_user",
Password: "",
Database: &testDbName,
},
Mariadb: mariaConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -944,7 +994,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "testpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mariadb.Password)
@@ -954,40 +1004,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
name: "MongoDB Database",
databaseType: DatabaseTypeMongodb,
createDatabase: func(workspaceID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
return &Database{
WorkspaceID: &workspaceID,
Name: "Test MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: 27017,
Username: "root",
Password: "original-password-secret",
Database: "test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
mongoConfig := getTestMongodbConfig()
mongoConfig.Password = ""
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated MongoDB Database",
Type: DatabaseTypeMongodb,
Mongodb: &mongodb.MongodbDatabase{
Version: tools.MongodbVersion8,
Host: "updated-host",
Port: 27018,
Username: "updated_user",
Password: "",
Database: "updated_test_db",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
},
Mongodb: mongoConfig,
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
@@ -997,7 +1030,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
assert.Equal(t, "rootpassword", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Mongodb.Password)

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
@@ -23,12 +24,13 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -94,6 +96,16 @@ func (m *MariadbDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -111,6 +123,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -131,15 +144,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MariadbDatabase) PopulateVersionIfEmpty(
func (m *MariadbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMariadbVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MariadbDatabase) PopulateVersion(
@@ -175,8 +221,8 @@ func (m *MariadbDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -185,17 +231,17 @@ func (m *MariadbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -205,33 +251,44 @@ func (m *MariadbDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "DELETE HISTORY", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MariadbDatabase) CreateReadOnlyUser(
@@ -261,7 +318,7 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
for attempt := range maxRetries {
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -326,10 +383,23 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
func (m *MariadbDatabase) HasPrivilege(priv string) bool {
return HasPrivilege(m.Privileges, priv)
}
func HasPrivilege(privileges, priv string) bool {
for _, p := range strings.Split(privileges, ",") {
if strings.TrimSpace(p) == priv {
return true
}
}
return false
}
func (m *MariadbDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
tlsConfig = "skip-verify"
}
return fmt.Sprintf(
@@ -420,6 +490,99 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return "", fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
backupPrivileges := []string{
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
}
detectedPrivileges := make(map[string]bool)
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("error iterating grants: %w", err)
}
if hasAllPrivileges {
for _, priv := range backupPrivileges {
detectedPrivileges[priv] = true
}
hasProcess = true
}
privileges := make([]string, 0, len(detectedPrivileges)+1)
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
if hasProcess {
privileges = append(privileges, "PROCESS")
}
sort.Strings(privileges)
return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
func checkBackupPermissions(privileges string) error {
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
var missingPrivileges []string
for _, priv := range requiredPrivileges {
if !HasPrivilege(privileges, priv) {
missingPrivileges = append(missingPrivileges, priv)
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
strings.Join(missingPrivileges, ", "),
)
}
return nil
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -18,6 +18,171 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE permission_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
container.Database,
limitedUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, limitedUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE backup_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(container.DB, backupUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -49,13 +214,56 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mariadbModel := createMariadbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MariadbDatabase{
Version: mariadbModel.Version,
Host: mariadbModel.Host,
Port: mariadbModel.Port,
Username: username,
Password: password,
Database: mariadbModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
dropUserSafe(container.DB, username)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -127,9 +335,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
@@ -382,6 +596,5 @@ func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
}
func dropUserSafe(db *sqlx.DB, username string) {
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so we ignore errors
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
}

View File

@@ -5,7 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"net/url"
"regexp"
"strings"
"time"
"databasus-backend/internal/util/encryption"
@@ -95,6 +97,16 @@ func (m *MongodbDatabase) TestConnection(
}
m.Version = detectedVersion
if err := checkBackupPermissions(
ctx,
client,
m.Username,
m.Database,
m.AuthDatabase,
); err != nil {
return err
}
return nil
}
@@ -134,14 +146,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields(
return nil
}
func (m *MongodbDatabase) PopulateVersionIfEmpty(
func (m *MongodbDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
}
@@ -185,10 +194,10 @@ func (m *MongodbDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
uri := m.buildConnectionURI(password)
@@ -196,7 +205,7 @@ func (m *MongodbDatabase) IsUserReadOnly(
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if disconnectErr := client.Disconnect(ctx); disconnectErr != nil {
@@ -218,44 +227,153 @@ func (m *MongodbDatabase) IsUserReadOnly(
}},
}).Decode(&result)
if err != nil {
return false, fmt.Errorf("failed to get user info: %w", err)
return false, nil, fmt.Errorf("failed to get user info: %w", err)
}
writeRoles := []string{
"readWrite", "readWriteAnyDatabase", "dbAdmin", "dbAdminAnyDatabase",
"userAdmin", "userAdminAnyDatabase", "clusterAdmin", "root",
"dbOwner", "backup", "restore",
writeRoles := map[string]bool{
"readWrite": true,
"readWriteAnyDatabase": true,
"dbAdmin": true,
"dbAdminAnyDatabase": true,
"userAdmin": true,
"userAdminAnyDatabase": true,
"clusterAdmin": true,
"clusterManager": true,
"hostManager": true,
"root": true,
"dbOwner": true,
"restore": true,
"__system": true,
}
// Roles that are read-only for our backup purposes
// The "backup" role has insert/update on mms.backup collection but is needed for mongodump
readOnlyRoles := map[string]bool{
"read": true,
"backup": true,
}
writeActions := map[string]bool{
"insert": true,
"update": true,
"remove": true,
"createCollection": true,
"dropCollection": true,
"createIndex": true,
"dropIndex": true,
"convertToCapped": true,
"dropDatabase": true,
"renameCollection": true,
"createUser": true,
"dropUser": true,
"updateUser": true,
"grantRole": true,
"revokeRole": true,
"dropRole": true,
"createRole": true,
"updateRole": true,
"enableSharding": true,
"shardCollection": true,
"addShard": true,
"removeShard": true,
"shutdown": true,
"replSetReconfig": true,
"replSetStateChange": true,
}
var detectedRoles []string
users, ok := result["users"].(bson.A)
if !ok || len(users) == 0 {
return true, nil
return true, detectedRoles, nil
}
user, ok := users[0].(bson.M)
if !ok {
return true, nil
return true, detectedRoles, nil
}
roles, ok := user["roles"].(bson.A)
if !ok {
return true, nil
return true, detectedRoles, nil
}
// Collect all role names and check for write roles
for _, roleDoc := range roles {
role, ok := roleDoc.(bson.M)
if !ok {
continue
}
roleName, _ := role["role"].(string)
for _, writeRole := range writeRoles {
if roleName == writeRole {
return false, nil
if roleName != "" {
detectedRoles = append(detectedRoles, roleName)
}
}
// Check if any detected role is a write role
for _, roleName := range detectedRoles {
if writeRoles[roleName] {
return false, detectedRoles, nil
}
}
// If all roles are known read-only roles (read, backup), skip inherited privilege check
allRolesReadOnly := true
for _, roleName := range detectedRoles {
if !readOnlyRoles[roleName] {
allRolesReadOnly = false
break
}
}
if allRolesReadOnly && len(detectedRoles) > 0 {
return true, detectedRoles, nil
}
// Check inherited privileges for custom roles
var privResult bson.M
err = adminDB.RunCommand(ctx, bson.D{
{Key: "usersInfo", Value: bson.D{
{Key: "user", Value: m.Username},
{Key: "db", Value: authDB},
}},
{Key: "showPrivileges", Value: true},
}).Decode(&privResult)
if err != nil {
return false, nil, fmt.Errorf("failed to get user privileges: %w", err)
}
privUsers, ok := privResult["users"].(bson.A)
if !ok || len(privUsers) == 0 {
return true, detectedRoles, nil
}
privUser, ok := privUsers[0].(bson.M)
if !ok {
return true, detectedRoles, nil
}
// Check inheritedPrivileges for write actions
inheritedPrivileges, ok := privUser["inheritedPrivileges"].(bson.A)
if ok {
for _, privDoc := range inheritedPrivileges {
priv, ok := privDoc.(bson.M)
if !ok {
continue
}
actions, ok := priv["actions"].(bson.A)
if !ok {
continue
}
for _, action := range actions {
actionStr, ok := action.(string)
if ok && writeActions[actionStr] {
return false, detectedRoles, nil
}
}
}
}
return true, nil
return true, detectedRoles, nil
}
func (m *MongodbDatabase) CreateReadOnlyUser(
@@ -290,7 +408,7 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
adminDB := client.Database(authDB)
err = adminDB.RunCommand(ctx, bson.D{
@@ -332,20 +450,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
m.Database,
authDB,
tlsOption,
tlsParams,
)
}
@@ -356,19 +474,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB = "admin"
}
tlsOption := "false"
tlsParams := ""
if m.IsHttps {
tlsOption = "true"
tlsParams = "&tls=true&tlsInsecure=true"
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000",
m.Username,
password,
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
authDB,
tlsOption,
tlsParams,
)
}
@@ -413,6 +531,128 @@ func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.Mong
}
}
// checkBackupPermissions verifies the user has sufficient privileges for mongodump backup.
// Required: 'read' role on target database OR 'backup' role on admin OR 'readAnyDatabase' role.
func checkBackupPermissions(
ctx context.Context,
client *mongo.Client,
username, database, authDatabase string,
) error {
authDB := authDatabase
if authDB == "" {
authDB = "admin"
}
adminDB := client.Database(authDB)
var result bson.M
err := adminDB.RunCommand(ctx, bson.D{
{Key: "usersInfo", Value: bson.D{
{Key: "user", Value: username},
{Key: "db", Value: authDB},
}},
{Key: "showPrivileges", Value: true},
}).Decode(&result)
if err != nil {
return fmt.Errorf("failed to get user info: %w", err)
}
users, ok := result["users"].(bson.A)
if !ok || len(users) == 0 {
return errors.New("insufficient permissions for backup. User not found")
}
user, ok := users[0].(bson.M)
if !ok {
return errors.New("insufficient permissions for backup. Could not parse user info")
}
// Check roles for backup permissions
roles, ok := user["roles"].(bson.A)
if !ok {
return errors.New("insufficient permissions for backup. No roles assigned")
}
backupRoles := map[string]bool{
"backup": true,
"root": true,
"readAnyDatabase": true,
"dbOwner": true,
"__system": true,
"clusterAdmin": true,
"readWriteAnyDatabase": true,
}
var userRoles []string
hasBackupRole := false
hasReadOnTargetDB := false
for _, roleDoc := range roles {
role, ok := roleDoc.(bson.M)
if !ok {
continue
}
roleName, _ := role["role"].(string)
roleDB, _ := role["db"].(string)
if roleName != "" {
userRoles = append(userRoles, roleName)
}
if backupRoles[roleName] {
hasBackupRole = true
}
if roleName == "read" && (roleDB == database || roleDB == "") {
hasReadOnTargetDB = true
}
if roleName == "readWrite" && (roleDB == database || roleDB == "") {
hasReadOnTargetDB = true
}
}
if hasBackupRole || hasReadOnTargetDB {
return nil
}
// Check inherited privileges for 'find' action on target database
inheritedPrivileges, ok := user["inheritedPrivileges"].(bson.A)
if ok {
for _, privDoc := range inheritedPrivileges {
priv, ok := privDoc.(bson.M)
if !ok {
continue
}
resource, ok := priv["resource"].(bson.M)
if !ok {
continue
}
resourceDB, _ := resource["db"].(string)
resourceCluster, _ := resource["cluster"].(bool)
isTargetDB := resourceDB == database || resourceDB == "" || resourceCluster
actions, ok := priv["actions"].(bson.A)
if !ok {
continue
}
for _, action := range actions {
actionStr, ok := action.(string)
if ok && actionStr == "find" && isTargetDB {
return nil
}
}
}
}
return fmt.Errorf(
"insufficient permissions for backup. Current roles: %s. Required: 'read' role on database '%s' OR 'backup' role on admin OR 'readAnyDatabase' role",
strings.Join(userRoles, ", "),
database,
)
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -4,6 +4,7 @@ import (
"context"
"fmt"
"log/slog"
"net/url"
"os"
"strconv"
"strings"
@@ -19,6 +20,138 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MongodbVersion
port string
}{
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("permission_test").Drop(ctx)
_, err := db.Collection("permission_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
adminDB := container.Client.Database(container.AuthDatabase)
err = adminDB.RunCommand(ctx, bson.D{
{Key: "createUser", Value: limitedUsername},
{Key: "pwd", Value: limitedPassword},
{Key: "roles", Value: bson.A{}},
}).Err()
assert.NoError(t, err)
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mongodbModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MongodbVersion
port string
}{
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMongodbContainer(t, tc.port, tc.version)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("backup_test").Drop(ctx)
_, err := db.Collection("backup_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
adminDB := container.Client.Database(container.AuthDatabase)
err = adminDB.RunCommand(ctx, bson.D{
{Key: "createUser", Value: backupUsername},
{Key: "pwd", Value: backupPassword},
{Key: "roles", Value: bson.A{
bson.D{
{Key: "role", Value: "read"},
{Key: "db", Value: container.Database},
},
}},
}).Err()
assert.NoError(t, err)
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mongodbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -46,13 +179,52 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, roles, "Root user should have roles")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
defer container.Client.Disconnect(context.Background())
ctx := context.Background()
db := container.Client.Database(container.Database)
_ = db.Collection("readonly_check_test").Drop(ctx)
_, err := db.Collection("readonly_check_test").InsertOne(ctx, bson.M{"data": "test1"})
assert.NoError(t, err)
mongodbModel := createMongodbModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
username, password, err := mongodbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MongodbDatabase{
Version: mongodbModel.Version,
Host: mongodbModel.Host,
Port: mongodbModel.Port,
Username: username,
Password: password,
Database: mongodbModel.Database,
AuthDatabase: mongodbModel.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
isReadOnly, roles, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.NotEmpty(t, roles, "Read-only user should have roles (read, backup)")
dropUserSafe(container.Client, username, container.AuthDatabase)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -271,6 +443,7 @@ func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
CpuCount: 1,
}
}
@@ -281,7 +454,8 @@ func connectWithCredentials(
) *mongo.Client {
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, container.Host, container.Port,
url.QueryEscape(username), url.QueryEscape(password),
container.Host, container.Port,
container.Database, container.AuthDatabase,
)

View File

@@ -7,6 +7,8 @@ import (
"fmt"
"log/slog"
"regexp"
"sort"
"strings"
"time"
"databasus-backend/internal/util/encryption"
@@ -22,12 +24,13 @@ type MysqlDatabase struct {
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MysqlDatabase) TableName() string {
@@ -93,6 +96,16 @@ func (m *MysqlDatabase) TestConnection(
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
if err := checkBackupPermissions(m.Privileges); err != nil {
return err
}
return nil
}
@@ -110,6 +123,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.Privileges = incoming.Privileges
if incoming.Password != "" {
m.Password = incoming.Password
@@ -130,15 +144,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields(
return nil
}
func (m *MysqlDatabase) PopulateVersionIfEmpty(
func (m *MysqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if m.Version != "" {
if m.Database == nil || *m.Database == "" {
return nil
}
return m.PopulateVersion(logger, encryptor, databaseID)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectMysqlVersion(ctx, db)
if err != nil {
return err
}
m.Version = detectedVersion
privileges, err := detectPrivileges(ctx, db, *m.Database)
if err != nil {
return err
}
m.Privileges = privileges
return nil
}
func (m *MysqlDatabase) PopulateVersion(
@@ -174,8 +221,8 @@ func (m *MysqlDatabase) PopulateVersion(
if err != nil {
return err
}
m.Version = detectedVersion
return nil
}
@@ -184,17 +231,17 @@ func (m *MysqlDatabase) IsUserReadOnly(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
dsn := m.buildDSN(password, *m.Database)
db, err := sql.Open("mysql", dsn)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := db.Close(); closeErr != nil {
@@ -204,33 +251,45 @@ func (m *MysqlDatabase) IsUserReadOnly(
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return false, fmt.Errorf("failed to check grants: %w", err)
return false, nil, fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
writePrivileges := []string{
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
"ALTER ROUTINE", "CREATE USER",
"CREATE TABLESPACE", "REFERENCES",
}
detectedPrivileges := make(map[string]bool)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return false, fmt.Errorf("failed to scan grant: %w", err)
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
}
for _, priv := range writePrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
return false, nil
detectedPrivileges[priv] = true
}
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating grants: %w", err)
return false, nil, fmt.Errorf("error iterating grants: %w", err)
}
return true, nil
privileges := make([]string, 0, len(detectedPrivileges))
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
func (m *MysqlDatabase) CreateReadOnlyUser(
@@ -259,7 +318,7 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
maxRetries := 3
for attempt := range maxRetries {
newUsername := fmt.Sprintf("databasus-%s", uuid.New().String()[:8])
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := db.BeginTx(ctx, nil)
if err != nil {
@@ -325,10 +384,23 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
func (m *MysqlDatabase) HasPrivilege(priv string) bool {
return HasPrivilege(m.Privileges, priv)
}
func HasPrivilege(privileges, priv string) bool {
for p := range strings.SplitSeq(privileges, ",") {
if strings.TrimSpace(p) == priv {
return true
}
}
return false
}
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
if m.IsHttps {
tlsConfig = "true"
tlsConfig = "skip-verify"
}
return fmt.Sprintf(
@@ -388,6 +460,99 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion {
}
}
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
if err != nil {
return "", fmt.Errorf("failed to check grants: %w", err)
}
defer func() { _ = rows.Close() }()
backupPrivileges := []string{
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
}
detectedPrivileges := make(map[string]bool)
hasProcess := false
hasAllPrivileges := false
escapedDB := strings.ReplaceAll(database, "_", "\\_")
dbPattern := regexp.MustCompile(
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
for rows.Next() {
var grant string
if err := rows.Scan(&grant); err != nil {
return "", fmt.Errorf("failed to scan grant: %w", err)
}
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
hasAllPrivileges = true
}
}
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
for _, priv := range backupPrivileges {
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
detectedPrivileges[priv] = true
}
}
}
if globalPattern.MatchString(grant) &&
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
hasProcess = true
}
}
if err := rows.Err(); err != nil {
return "", fmt.Errorf("error iterating grants: %w", err)
}
if hasAllPrivileges {
for _, priv := range backupPrivileges {
detectedPrivileges[priv] = true
}
hasProcess = true
}
privileges := make([]string, 0, len(detectedPrivileges)+1)
for priv := range detectedPrivileges {
privileges = append(privileges, priv)
}
if hasProcess {
privileges = append(privileges, "PROCESS")
}
sort.Strings(privileges)
return strings.Join(privileges, ","), nil
}
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
func checkBackupPermissions(privileges string) error {
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
var missingPrivileges []string
for _, priv := range requiredPrivileges {
if !HasPrivilege(privileges, priv) {
missingPrivileges = append(missingPrivileges, priv)
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
strings.Join(missingPrivileges, ", "),
)
}
return nil
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -18,6 +18,165 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE permission_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
container.Database,
limitedUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", limitedUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
assert.Contains(t, err.Error(), "insufficient permissions")
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE backup_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
"GRANT PROCESS ON *.* TO '%s'@'%%'",
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", backupUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -42,13 +201,57 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Root user should not be read-only")
assert.NotEmpty(t, privileges, "Root user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
defer container.DB.Close()
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
assert.NoError(t, err)
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)`)
assert.NoError(t, err)
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
assert.NoError(t, err)
mysqlModel := createMysqlModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &MysqlDatabase{
Version: mysqlModel.Version,
Host: mysqlModel.Host,
Port: mysqlModel.Port,
Username: username,
Password: password,
Database: mysqlModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -109,9 +312,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",

View File

@@ -137,16 +137,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
return nil
}
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
// PopulateDbData detects and sets the PostgreSQL version.
// This should be called before encrypting sensitive fields.
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
func (p *PostgresqlDatabase) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.Version != "" {
return nil
}
return p.PopulateVersion(logger, encryptor, databaseID)
}
@@ -192,29 +189,33 @@ func (p *PostgresqlDatabase) PopulateVersion(
// IsUserReadOnly checks if the database user has read-only privileges.
//
// This method performs a comprehensive security check by examining:
// - Role-level attributes (superuser, createrole, createdb)
// - Role-level attributes (superuser, createrole, createdb, bypassrls, replication)
// - Database-level privileges (CREATE, TEMP)
// - Schema-level privileges (CREATE on any non-system schema)
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
// - Function-level privileges (EXECUTE on SECURITY DEFINER functions)
//
// A user is considered read-only only if they have ZERO write privileges
// across all three levels. This ensures the database user follows the
// across all levels. This ensures the database user follows the
// principle of least privilege for backup operations.
//
// Returns: (isReadOnly, detectedPrivileges, error)
func (p *PostgresqlDatabase) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
) (bool, []string, error) {
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
@@ -222,22 +223,38 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
}
}()
var privileges []string
// LEVEL 1: Check role-level attributes
var isSuperuser, canCreateRole, canCreateDB bool
var isSuperuser, canCreateRole, canCreateDB, canBypassRLS, canReplication bool
err = conn.QueryRow(ctx, `
SELECT
rolsuper,
rolcreaterole,
rolcreatedb
rolcreatedb,
rolbypassrls,
rolreplication
FROM pg_roles
WHERE rolname = current_user
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB, &canBypassRLS, &canReplication)
if err != nil {
return false, fmt.Errorf("failed to check role attributes: %w", err)
return false, nil, fmt.Errorf("failed to check role attributes: %w", err)
}
if isSuperuser || canCreateRole || canCreateDB {
return false, nil
if isSuperuser {
privileges = append(privileges, "SUPERUSER")
}
if canCreateRole {
privileges = append(privileges, "CREATEROLE")
}
if canCreateDB {
privileges = append(privileges, "CREATEDB")
}
if canBypassRLS {
privileges = append(privileges, "BYPASSRLS")
}
if canReplication {
privileges = append(privileges, "REPLICATION")
}
// LEVEL 2: Check database-level privileges
@@ -248,46 +265,34 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
`).Scan(&canCreate, &canTemp)
if err != nil {
return false, fmt.Errorf("failed to check database privileges: %w", err)
return false, nil, fmt.Errorf("failed to check database privileges: %w", err)
}
if canCreate || canTemp {
return false, nil
if canCreate {
privileges = append(privileges, "CREATE (database)")
}
if canTemp {
privileges = append(privileges, "TEMP")
}
// LEVEL 2.5: Check schema-level CREATE privileges
schemaRows, err := conn.Query(ctx, `
SELECT DISTINCT nspname
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
`)
var hasSchemaCreate bool
err = conn.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
)
`).Scan(&hasSchemaCreate)
if err != nil {
return false, fmt.Errorf("failed to check schema privileges: %w", err)
return false, nil, fmt.Errorf("failed to check schema privileges: %w", err)
}
defer schemaRows.Close()
// If user has CREATE privilege on any schema, they're not read-only
if schemaRows.Next() {
return false, nil
}
if err := schemaRows.Err(); err != nil {
return false, fmt.Errorf("error iterating schema privileges: %w", err)
if hasSchemaCreate {
privileges = append(privileges, "CREATE (schema)")
}
// LEVEL 3: Check table-level write permissions
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, fmt.Errorf("failed to check table privileges: %w", err)
}
defer rows.Close()
writePrivileges := map[string]bool{
"INSERT": true,
"UPDATE": true,
@@ -297,22 +302,56 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
"TRIGGER": true,
}
var tablePrivileges []string
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, nil, fmt.Errorf("failed to check table privileges: %w", err)
}
for rows.Next() {
var privilege string
if err := rows.Scan(&privilege); err != nil {
return false, fmt.Errorf("failed to scan privilege: %w", err)
}
if writePrivileges[privilege] {
return false, nil
rows.Close()
return false, nil, fmt.Errorf("failed to scan privilege: %w", err)
}
tablePrivileges = append(tablePrivileges, privilege)
}
rows.Close()
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating privileges: %w", err)
return false, nil, fmt.Errorf("error iterating privileges: %w", err)
}
return true, nil
for _, privilege := range tablePrivileges {
if writePrivileges[privilege] {
privileges = append(privileges, privilege)
}
}
// LEVEL 4: Check for EXECUTE privilege on functions that are SECURITY DEFINER
var funcCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_proc p
JOIN pg_namespace n ON p.pronamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
AND p.prosecdef = true
AND has_function_privilege(current_user, p.oid, 'EXECUTE')
`).Scan(&funcCount)
if err != nil {
return false, nil, fmt.Errorf("failed to check function privileges: %w", err)
}
if funcCount > 0 {
privileges = append(privileges, "EXECUTE (SECURITY DEFINER)")
}
isReadOnly := len(privileges) == 0
return isReadOnly, privileges, nil
}
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
@@ -383,7 +422,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
newPassword := uuid.New().String()
newPassword := encryption.GenerateComplexPassword()
tx, err := conn.Begin(ctx)
if err != nil {
@@ -631,13 +670,9 @@ func testSingleDatabaseConnection(
}
postgresDb.Version = detectedVersion
// Test if we can perform basic operations (like pg_dump would need)
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
return fmt.Errorf(
"basic operations test failed for database '%s': %w",
*postgresDb.Database,
err,
)
// Verify user has sufficient permissions for backup operations
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
return err
}
return nil
@@ -670,18 +705,73 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
}
}
// testBasicOperations tests basic operations that backup tools need
func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) error {
var hasCreatePriv bool
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
var missingPrivileges []string
// Check CONNECT privilege on database
var hasConnect bool
err := conn.QueryRow(ctx, "SELECT has_database_privilege(current_user, current_database(), 'CONNECT')").
Scan(&hasCreatePriv)
Scan(&hasConnect)
if err != nil {
return fmt.Errorf("cannot check database privileges: %w", err)
}
if !hasConnect {
missingPrivileges = append(missingPrivileges, "CONNECT on database")
}
if !hasCreatePriv {
return fmt.Errorf("user does not have CONNECT privilege on database '%s'", dbName)
// Check USAGE privilege on at least one non-system schema
var schemaCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname NOT LIKE 'pg_temp_%'
AND n.nspname NOT LIKE 'pg_toast_temp_%'
`).Scan(&schemaCount)
if err != nil {
return fmt.Errorf("cannot check schema privileges: %w", err)
}
if schemaCount == 0 {
missingPrivileges = append(missingPrivileges, "USAGE on at least one schema")
}
// Check SELECT privilege on at least one table (if tables exist)
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
var tableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
`).Scan(&tableCount)
if err != nil {
return fmt.Errorf("cannot check table count: %w", err)
}
if tableCount > 0 {
// Check if user has SELECT on at least one of these tables
var selectableTableCount int
err = conn.QueryRow(ctx, `
SELECT COUNT(*)
FROM pg_catalog.pg_tables t
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
`).Scan(&selectableTableCount)
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")
}
}
if len(missingPrivileges) > 0 {
return fmt.Errorf(
"insufficient permissions for backup. Missing: %s. Required: CONNECT on database, USAGE on schemas, SELECT on tables",
strings.Join(missingPrivileges, ", "),
)
}
return nil
@@ -695,16 +785,22 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password s
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
"host=%s port=%d user=%s password='%s' dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
p.Host,
p.Port,
p.Username,
password,
escapeConnectionStringValue(password),
dbName,
sslMode,
)
}
func escapeConnectionStringValue(value string) string {
value = strings.ReplaceAll(value, `\`, `\\`)
value = strings.ReplaceAll(value, `'`, `\'`)
return value
}
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,

View File

@@ -19,6 +19,230 @@ import (
"databasus-backend/internal/util/tools"
)
func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
passwordWithSpaces := "test password with spaces"
usernameWithSpaces := fmt.Sprintf("testuser_spaces_%s", uuid.New().String()[:8])
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS password_test CASCADE;
CREATE TABLE password_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO password_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
usernameWithSpaces,
passwordWithSpaces,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
usernameWithSpaces,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
usernameWithSpaces,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
usernameWithSpaces,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, usernameWithSpaces))
}()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum("16"),
Host: container.Host,
Port: container.Port,
Username: usernameWithSpaces,
Password: passwordWithSpaces,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = pgModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
}
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS permission_test CASCADE;
CREATE TABLE permission_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO permission_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
limitedUsername := fmt.Sprintf("limited_user_%s", uuid.New().String()[:8])
limitedPassword := "limitedpassword123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
limitedUsername,
limitedPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedUsername))
}()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedUsername,
Password: limitedPassword,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = pgModel.TestConnection(logger, nil, uuid.New())
assert.Error(t, err)
if err != nil {
assert.Contains(t, err.Error(), "insufficient permissions")
}
})
}
}
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS backup_test CASCADE;
CREATE TABLE backup_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO backup_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
backupUsername := fmt.Sprintf("backup_user_%s", uuid.New().String()[:8])
backupPassword := "backuppassword123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
backupUsername,
backupPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
backupUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
backupUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, backupUsername))
}()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: backupUsername,
Password: backupPassword,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = pgModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
})
}
}
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -45,13 +269,60 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Admin user should not be read-only")
assert.NotEmpty(t, privileges, "Admin user should have privileges")
})
}
}
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS readonly_check_test CASCADE;
CREATE TABLE readonly_check_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO readonly_check_test (data) VALUES ('test1');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
CpuCount: 1,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Read-only user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -106,9 +377,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
@@ -143,7 +420,6 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -187,7 +463,6 @@ func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -237,7 +512,6 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "data_b", dataB)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
@@ -486,6 +760,7 @@ func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
Password: container.Password,
Database: &container.Database,
IsHttps: false,
CpuCount: 1,
}
}

View File

@@ -6,5 +6,6 @@ type CreateReadOnlyUserResponse struct {
}
type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
IsReadOnly bool `json:"isReadOnly"`
Privileges []string `json:"privileges"`
}

View File

@@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
return nil
}
func (d *Database) PopulateVersionIfEmpty(
func (d *Database) PopulateDbData(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
if d.Postgresql != nil {
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Postgresql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mysql != nil {
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mysql.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mariadb != nil {
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mariadb.PopulateDbData(logger, encryptor, d.ID)
}
if d.Mongodb != nil {
return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
return d.Mongodb.PopulateDbData(logger, encryptor, d.ID)
}
return nil
}

View File

@@ -82,8 +82,8 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -149,8 +149,8 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
@@ -594,17 +594,17 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
func (s *DatabaseService) IsUserReadOnly(
user *users_models.User,
database *Database,
) (bool, error) {
) (bool, []string, error) {
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return false, err
return false, nil, err
}
if existingDatabase.WorkspaceID == nil {
return false, errors.New("cannot check user for database without workspace")
return false, nil, errors.New("cannot check user for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
@@ -612,31 +612,34 @@ func (s *DatabaseService) IsUserReadOnly(
user,
)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this database")
return false, nil, errors.New("insufficient permissions to access this database")
}
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
return false, errors.New("database does not belong to this workspace")
return false, nil, errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return false, err
return false, nil, err
}
usingDatabase = existingDatabase
} else {
if database.WorkspaceID != nil {
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return false, err
return false, nil, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this workspace")
return false, nil, errors.New("insufficient permissions to access this workspace")
}
}
@@ -676,7 +679,7 @@ func (s *DatabaseService) IsUserReadOnly(
usingDatabase.ID,
)
default:
return false, errors.New("read-only check not supported for this database type")
return false, nil, errors.New("read-only check not supported for this database type")
}
}

View File

@@ -1,6 +1,12 @@
package databases
import (
"fmt"
"strconv"
"databasus-backend/internal/config"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
@@ -9,6 +15,71 @@ import (
"github.com/google/uuid"
)
func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func GetTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func GetTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func CreateTestDatabase(
workspaceID uuid.UUID,
storage *storages.Storage,
@@ -18,16 +89,7 @@ func CreateTestDatabase(
WorkspaceID: &workspaceID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
CpuCount: 1,
},
Postgresql: GetTestPostgresConfig(),
Notifiers: []notifiers.Notifier{
*notifier,
},

View File

@@ -12,13 +12,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -111,7 +109,13 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -205,20 +209,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -10,13 +10,11 @@ import (
"github.com/stretchr/testify/assert"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/postgresql"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
@@ -90,7 +88,13 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -228,7 +232,13 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -293,20 +303,11 @@ func createTestDatabaseViaAPI(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(

View File

@@ -263,7 +263,12 @@ func (c *NotifierController) TransferNotifierToWorkspace(ctx *gin.Context) {
return
}
if err := c.notifierService.TransferNotifierToWorkspace(user, id, request.TargetWorkspaceID, nil); err != nil {
if err := c.notifierService.TransferNotifierToWorkspace(
user,
id,
request.TargetWorkspaceID,
nil,
); err != nil {
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -1050,8 +1050,20 @@ func Test_TransferNotifier_PermissionsEnforced(t *testing.T) {
testUserToken = admin.Token
} else if tt.sourceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(
sourceWorkspace,
testUser,
*tt.sourceRole,
sourceOwner.Token,
router,
)
workspaces_testing.AddMemberToWorkspace(
targetWorkspace,
testUser,
*tt.targetRole,
targetOwner.Token,
router,
)
testUserToken = testUser.Token
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
"strconv"
"strings"
"testing"
"time"
@@ -15,6 +16,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
@@ -288,7 +290,12 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
},
}
} else {
mysqlDB := createTestMySQLDatabase("Test MySQL DB", workspace.ID, owner.Token, router)
mysqlDB := createTestMySQLDatabase(
"Test MySQL DB",
workspace.ID,
owner.Token,
router,
)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
@@ -390,20 +397,11 @@ func createTestDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
CpuCount: 1,
},
Postgresql: databases.GetTestPostgresConfig(),
}
w := workspaces_testing.MakeAPIRequest(
@@ -434,7 +432,18 @@ func createTestMySQLDatabase(
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
env := config.GetEnv()
portStr := env.TestMysql80Port
if portStr == "" {
portStr = "33080"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MYSQL_80_PORT: %v", err))
}
testDbName := "testdb"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
@@ -442,9 +451,9 @@ func createTestMySQLDatabase(
Mysql: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "password",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
},
}
@@ -526,7 +535,13 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(context.Background(), fieldEncryptor, logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(
context.Background(),
fieldEncryptor,
logger,
backup.ID,
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -229,8 +229,8 @@ func (s *RestoreService) RestoreBackup(
Mongodb: requestDTO.MongodbDatabase,
}
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
isExcludeExtensions := false

View File

@@ -131,7 +131,6 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
"--if-exists",
"--no-owner",
"--no-acl",
"--no-comments",
}
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)

View File

@@ -263,7 +263,12 @@ func (c *StorageController) TransferStorageToWorkspace(ctx *gin.Context) {
return
}
if err := c.storageService.TransferStorageToWorkspace(user, id, request.TargetWorkspaceID, nil); err != nil {
if err := c.storageService.TransferStorageToWorkspace(
user,
id,
request.TargetWorkspaceID,
nil,
); err != nil {
if errors.Is(err, ErrInsufficientPermissionsInSourceWorkspace) ||
errors.Is(err, ErrInsufficientPermissionsInTargetWorkspace) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -1071,8 +1071,20 @@ func Test_TransferStorage_PermissionsEnforced(t *testing.T) {
testUserToken = admin.Token
} else if tt.sourceRole != nil {
testUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(sourceWorkspace, testUser, *tt.sourceRole, sourceOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(targetWorkspace, testUser, *tt.targetRole, targetOwner.Token, router)
workspaces_testing.AddMemberToWorkspace(
sourceWorkspace,
testUser,
*tt.sourceRole,
sourceOwner.Token,
router,
)
workspaces_testing.AddMemberToWorkspace(
targetWorkspace,
testUser,
*tt.targetRole,
targetOwner.Token,
router,
)
testUserToken = testUser.Token
}

View File

@@ -26,6 +26,7 @@ const (
azureResponseTimeout = 30 * time.Second
azureIdleConnTimeout = 90 * time.Second
azureTLSHandshakeTimeout = 30 * time.Second
azureDeleteTimeout = 30 * time.Second
// Chunk size for block blob uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -186,8 +187,11 @@ func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileI
blobName := s.buildBlobName(fileID.String())
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
defer cancel()
_, err = client.DeleteBlob(
context.TODO(),
ctx,
s.ContainerName,
blobName,
nil,

View File

@@ -18,6 +18,7 @@ import (
const (
ftpConnectTimeout = 30 * time.Second
ftpTestConnectTimeout = 10 * time.Second
ftpDeleteTimeout = 30 * time.Second
ftpChunkSize = 16 * 1024 * 1024
)
@@ -134,7 +135,10 @@ func (f *FTPStorage) GetFile(
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
conn, err := f.connect(encryptor, ftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
defer cancel()
conn, err := f.connectWithContext(ctx, encryptor, ftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to FTP: %w", err)
}

View File

@@ -27,6 +27,7 @@ const (
gdResponseTimeout = 30 * time.Second
gdIdleConnTimeout = 90 * time.Second
gdTLSHandshakeTimeout = 30 * time.Second
gdDeleteTimeout = 30 * time.Second
// Chunk size for Google Drive resumable uploads - 16MB provides good balance
// between memory usage and upload efficiency. Google Drive requires chunks
@@ -185,7 +186,9 @@ func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
defer cancel()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {

View File

@@ -18,6 +18,8 @@ import (
)
const (
nasDeleteTimeout = 30 * time.Second
// Chunk size for NAS uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for NAS to confirm receipt.
@@ -193,7 +195,10 @@ func (n *NASStorage) GetFile(
}
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
defer cancel()
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -211,10 +216,8 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
filePath := n.getFilePath(fileID.String())
// Check if file exists before trying to delete
_, err = fs.Stat(filePath)
if err != nil {
// File doesn't exist, consider it already deleted
return nil
}

View File

@@ -22,6 +22,7 @@ import (
const (
rcloneOperationTimeout = 30 * time.Second
rcloneDeleteTimeout = 30 * time.Second
)
var rcloneConfigMu sync.Mutex
@@ -115,7 +116,8 @@ func (r *RcloneStorage) GetFile(
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
ctx := context.Background()
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
defer cancel()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {

View File

@@ -26,6 +26,7 @@ const (
s3ResponseTimeout = 30 * time.Second
s3IdleConnTimeout = 90 * time.Second
s3TLSHandshakeTimeout = 30 * time.Second
s3DeleteTimeout = 30 * time.Second
// Chunk size for multipart uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
@@ -228,9 +229,11 @@ func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.
objectKey := s.buildObjectKey(fileID.String())
// Delete the object using MinIO client
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
defer cancel()
err = client.RemoveObject(
context.TODO(),
ctx,
s.S3Bucket,
objectKey,
minio.RemoveObjectOptions{},

View File

@@ -19,6 +19,7 @@ import (
const (
sftpConnectTimeout = 30 * time.Second
sftpTestConnectTimeout = 10 * time.Second
sftpDeleteTimeout = 30 * time.Second
)
type SFTPStorage struct {
@@ -154,7 +155,10 @@ func (s *SFTPStorage) GetFile(
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
defer cancel()
client, sshConn, err := s.connectWithContext(ctx, encryptor, sftpDeleteTimeout)
if err != nil {
return fmt.Errorf("failed to connect to SFTP: %w", err)
}

View File

@@ -32,21 +32,23 @@ import (
test_utils "databasus-backend/internal/util/testing"
)
const createAndFillTableQuery = `
DROP TABLE IF EXISTS test_data;
func createAndFillTableQuery(tableName string) string {
return fmt.Sprintf(`
DROP TABLE IF EXISTS %s;
CREATE TABLE test_data (
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
value INTEGER NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
INSERT INTO test_data (name, value) VALUES
INSERT INTO %s (name, value) VALUES
('test1', 100),
('test2', 200),
('test3', 300);
`
`, tableName, tableName, tableName)
}
type PostgresContainer struct {
Host string
@@ -378,9 +380,14 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -436,12 +443,19 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cp
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -875,9 +889,14 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("ReadOnly Test Workspace", user, router)
@@ -941,12 +960,19 @@ func testBackupRestoreWithReadOnlyUserForVersion(t *testing.T, pgVersion string,
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -1106,9 +1132,14 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
tableName := fmt.Sprintf("test_data_%s", uuid.New().String()[:8])
_, err = container.DB.Exec(createAndFillTableQuery(tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
}()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
@@ -1163,12 +1194,19 @@ func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, p
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
fmt.Sprintf(
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = '%s')",
tableName,
),
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
assert.True(
t,
tableExists,
fmt.Sprintf("Table '%s' should exist in restored database", tableName),
)
verifyDataIntegrity(t, container.DB, newDB)
verifyDataIntegrity(t, container.DB, newDB, tableName)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
@@ -1630,14 +1668,14 @@ func createSupabaseRestoreViaAPI(
)
}
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB, tableName string) {
var originalData []TestDataItem
var restoredData []TestDataItem
err := originalDB.Select(&originalData, "SELECT * FROM test_data ORDER BY id")
err := originalDB.Select(&originalData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
err = restoredDB.Select(&restoredData, "SELECT * FROM test_data ORDER BY id")
err = restoredDB.Select(&restoredData, fmt.Sprintf("SELECT * FROM %s ORDER BY id", tableName))
assert.NoError(t, err)
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")

View File

@@ -56,11 +56,17 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
// If user exists with INVITED status, activate them and set password
if existingUser != nil && existingUser.Status == users_enums.UserStatusInvited {
if err := s.userRepository.UpdateUserPassword(existingUser.ID, hashedPasswordStr); err != nil {
if err := s.userRepository.UpdateUserPassword(
existingUser.ID,
hashedPasswordStr,
); err != nil {
return fmt.Errorf("failed to set password: %w", err)
}
if err := s.userRepository.UpdateUserStatus(existingUser.ID, users_enums.UserStatusActive); err != nil {
if err := s.userRepository.UpdateUserStatus(
existingUser.ID,
users_enums.UserStatusActive,
); err != nil {
return fmt.Errorf("failed to activate user: %w", err)
}
@@ -635,7 +641,10 @@ func (s *UserService) getOrCreateUserFromOAuth(
if userByEmail != nil {
if userByEmail.Status == users_enums.UserStatusInvited {
if err := s.userRepository.UpdateUserStatus(userByEmail.ID, users_enums.UserStatusActive); err != nil {
if err := s.userRepository.UpdateUserStatus(
userByEmail.ID,
users_enums.UserStatusActive,
); err != nil {
return nil, fmt.Errorf("failed to activate user: %w", err)
}

View File

@@ -161,7 +161,12 @@ func (c *MembershipController) ChangeMemberRole(ctx *gin.Context) {
return
}
if err := c.membershipService.ChangeMemberRole(workspaceID, userID, &request, user); err != nil {
if err := c.membershipService.ChangeMemberRole(
workspaceID,
userID,
&request,
user,
); err != nil {
if errors.Is(err, workspaces_errors.ErrInsufficientPermissionsToManageMembers) ||
errors.Is(err, workspaces_errors.ErrOnlyOwnerCanAddManageAdmins) {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})

View File

@@ -123,7 +123,11 @@ func Test_GetWorkspaceMembers_PermissionsEnforced(t *testing.T) {
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(resp.Body), "insufficient permissions to view workspace members")
assert.Contains(
t,
string(resp.Body),
"insufficient permissions to view workspace members",
)
}
})
}
@@ -1202,7 +1206,11 @@ func Test_TransferWorkspaceOwnership_PermissionsEnforced(t *testing.T) {
if tt.expectSuccess {
assert.Contains(t, string(resp.Body), "Ownership transferred successfully")
} else {
assert.Contains(t, string(resp.Body), "only workspace owner or admin can transfer ownership")
assert.Contains(
t,
string(resp.Body),
"only workspace owner or admin can transfer ownership",
)
}
})
}

View File

@@ -100,7 +100,11 @@ func Test_CreateWorkspace_PermissionsEnforced(t *testing.T) {
request,
tt.expectedStatusCode,
)
assert.Contains(t, string(resp.Body), "insufficient permissions to create workspaces")
assert.Contains(
t,
string(resp.Body),
"insufficient permissions to create workspaces",
)
}
})
}
@@ -263,7 +267,13 @@ func Test_GetSingleWorkspace_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -365,7 +375,13 @@ func Test_UpdateWorkspace_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -396,7 +412,11 @@ func Test_UpdateWorkspace_PermissionsEnforced(t *testing.T) {
updateRequest,
tt.expectedStatusCode,
)
assert.Contains(t, string(resp.Body), "insufficient permissions to update workspace")
assert.Contains(
t,
string(resp.Body),
"insufficient permissions to update workspace",
)
}
})
}
@@ -461,7 +481,13 @@ func Test_DeleteWorkspace_PermissionsEnforced(t *testing.T) {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
*tt.workspaceRole,
owner.Token,
router,
)
testUserToken = member.Token
}
@@ -475,7 +501,11 @@ func Test_DeleteWorkspace_PermissionsEnforced(t *testing.T) {
if tt.expectSuccess {
assert.Contains(t, string(resp.Body), "Workspace deleted successfully")
} else {
assert.Contains(t, string(resp.Body), "only workspace owner or admin can delete workspace")
assert.Contains(
t,
string(resp.Body),
"only workspace owner or admin can delete workspace",
)
}
})
}

View File

@@ -173,7 +173,11 @@ func (s *MembershipService) ChangeMemberRole(
return workspaces_errors.ErrUserNotFound
}
if err := s.membershipRepository.UpdateMemberRole(memberUserID, workspaceID, request.Role); err != nil {
if err := s.membershipRepository.UpdateMemberRole(
memberUserID,
workspaceID,
request.Role,
); err != nil {
return fmt.Errorf("failed to update member role: %w", err)
}
@@ -283,11 +287,19 @@ func (s *MembershipService) TransferOwnership(
return workspaces_errors.ErrNoCurrentWorkspaceOwner
}
if err := s.membershipRepository.UpdateMemberRole(newOwner.ID, workspaceID, users_enums.WorkspaceRoleOwner); err != nil {
if err := s.membershipRepository.UpdateMemberRole(
newOwner.ID,
workspaceID,
users_enums.WorkspaceRoleOwner,
); err != nil {
return fmt.Errorf("failed to update new owner role: %w", err)
}
if err := s.membershipRepository.UpdateMemberRole(currentOwner.UserID, workspaceID, users_enums.WorkspaceRoleAdmin); err != nil {
if err := s.membershipRepository.UpdateMemberRole(
currentOwner.UserID,
workspaceID,
users_enums.WorkspaceRoleAdmin,
); err != nil {
return fmt.Errorf("failed to update previous owner role: %w", err)
}

View File

@@ -0,0 +1,59 @@
package encryption
import (
"crypto/rand"
"math/big"
)
// GenerateComplexPassword creates a password that meets common cloud provider requirements:
// - At least one lowercase letter
// - At least one uppercase letter
// - At least one digit
// - At least one special character
// - 24 characters for security
func GenerateComplexPassword() string {
const (
lowercase = "abcdefghijklmnopqrstuvwxyz"
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
digits = "0123456789"
special = "!@#$%^&*()-_=+"
all = lowercase + uppercase + digits + special
)
password := make([]byte, 24)
// Ensure at least one character from each required set
password[0] = randomChar(lowercase)
password[1] = randomChar(uppercase)
password[2] = randomChar(digits)
password[3] = randomChar(special)
// Fill the rest with random characters from all sets
for i := 4; i < len(password); i++ {
password[i] = randomChar(all)
}
// Shuffle the password to avoid predictable positions
shuffleBytes(password)
return string(password)
}
func randomChar(charset string) byte {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return charset[0]
}
return charset[n.Int64()]
}
func shuffleBytes(b []byte) {
for i := len(b) - 1; i > 0; i-- {
n, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
if err != nil {
continue
}
j := n.Int64()
b[i], b[j] = b[j], b[i]
}
}

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
DROP COLUMN is_exclude_events;
-- +goose StatementEnd

View File

@@ -0,0 +1,23 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mysql_databases
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
ALTER TABLE mariadb_databases
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
ALTER TABLE mariadb_databases
DROP COLUMN is_exclude_events;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
ALTER TABLE mariadb_databases
DROP COLUMN privileges;
ALTER TABLE mysql_databases
DROP COLUMN privileges;
-- +goose StatementEnd

View File

@@ -0,0 +1,43 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE notifiers
DROP CONSTRAINT fk_notifiers_workspace_id;
ALTER TABLE notifiers
ADD CONSTRAINT fk_notifiers_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id)
ON DELETE CASCADE;
ALTER TABLE storages
DROP CONSTRAINT fk_storages_workspace_id;
ALTER TABLE storages
ADD CONSTRAINT fk_storages_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id)
ON DELETE CASCADE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE notifiers
DROP CONSTRAINT fk_notifiers_workspace_id;
ALTER TABLE notifiers
ADD CONSTRAINT fk_notifiers_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id);
ALTER TABLE storages
DROP CONSTRAINT fk_storages_workspace_id;
ALTER TABLE storages
ADD CONSTRAINT fk_storages_workspace_id
FOREIGN KEY (workspace_id)
REFERENCES workspaces (id);
-- +goose StatementEnd

View File

@@ -0,0 +1,44 @@
-- +goose Up
-- +goose StatementBegin
CREATE TABLE download_tokens (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
token TEXT NOT NULL UNIQUE,
backup_id UUID NOT NULL,
user_id UUID NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
used BOOLEAN NOT NULL DEFAULT FALSE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
ALTER TABLE download_tokens
ADD CONSTRAINT fk_download_tokens_backup_id
FOREIGN KEY (backup_id)
REFERENCES backups (id)
ON DELETE CASCADE;
ALTER TABLE download_tokens
ADD CONSTRAINT fk_download_tokens_user_id
FOREIGN KEY (user_id)
REFERENCES users (id)
ON DELETE CASCADE;
CREATE INDEX idx_download_tokens_token ON download_tokens (token);
CREATE INDEX idx_download_tokens_expires_at ON download_tokens (expires_at);
CREATE INDEX idx_download_tokens_backup_id ON download_tokens (backup_id);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP INDEX IF EXISTS idx_download_tokens_backup_id;
DROP INDEX IF EXISTS idx_download_tokens_expires_at;
DROP INDEX IF EXISTS idx_download_tokens_token;
ALTER TABLE download_tokens DROP CONSTRAINT IF EXISTS fk_download_tokens_user_id;
ALTER TABLE download_tokens DROP CONSTRAINT IF EXISTS fk_download_tokens_backup_id;
DROP TABLE IF EXISTS download_tokens;
-- +goose StatementEnd

View File

@@ -9,13 +9,13 @@ echo "Installing PostgreSQL, MySQL, MariaDB and MongoDB client tools for Linux (
echo
# Check if running on supported system
if ! command -v apt-get &> /dev/null; then
if ! command -v apt-get > /dev/null 2>&1; then
echo "Error: This script requires apt-get (Debian/Ubuntu-like system)"
exit 1
fi
# Check if running as root or with sudo
if [[ $EUID -eq 0 ]]; then
if [ $EUID -eq 0 ]; then
SUDO=""
else
SUDO="sudo"
@@ -107,6 +107,12 @@ for version in $mysql_versions; do
version_dir="$MYSQL_DIR/mysql-$version"
mkdir -p "$version_dir/bin"
# Skip if already exists
if [ -f "$version_dir/bin/mysqldump" ]; then
echo " MySQL $version already installed, skipping..."
continue
fi
# Download MySQL client tools from official CDN
# Note: 5.7 is in Downloads, 8.0, 8.4 specific versions are in archives, 9.5 is in MySQL-9.5
case $version in
@@ -132,11 +138,14 @@ for version in $mysql_versions; do
wget -q "$MYSQL_URL" -O "mysql-$version.tar.gz" || wget -q "$MYSQL_URL" -O "mysql-$version.tar.xz"
echo " Extracting MySQL $version..."
if [[ "$MYSQL_URL" == *.xz ]]; then
tar -xJf "mysql-$version.tar.xz" 2>/dev/null || tar -xJf "mysql-$version.tar.gz" 2>/dev/null
else
tar -xzf "mysql-$version.tar.gz" 2>/dev/null || tar -xzf "mysql-$version.tar.xz" 2>/dev/null
fi
case "$MYSQL_URL" in
*.xz)
tar -xJf "mysql-$version.tar.xz" 2>/dev/null || tar -xJf "mysql-$version.tar.gz" 2>/dev/null
;;
*)
tar -xzf "mysql-$version.tar.gz" 2>/dev/null || tar -xzf "mysql-$version.tar.xz" 2>/dev/null
;;
esac
# Find extracted directory
EXTRACTED_DIR=$(ls -d mysql-*/ 2>/dev/null | head -1)
@@ -175,12 +184,7 @@ echo "Installing MariaDB client tools to: $MARIADB_DIR"
# Install dependencies
$SUDO apt-get install -y -qq apt-transport-https curl
# MariaDB versions to install with their URLs
declare -A MARIADB_URLS=(
["10.6"]="https://archive.mariadb.org/mariadb-10.6.21/bintar-linux-systemd-x86_64/mariadb-10.6.21-linux-systemd-x86_64.tar.gz"
["12.1"]="https://archive.mariadb.org/mariadb-12.1.2/bintar-linux-systemd-x86_64/mariadb-12.1.2-linux-systemd-x86_64.tar.gz"
)
# MariaDB versions to install
mariadb_versions="10.6 12.1"
for version in $mariadb_versions; do
@@ -195,7 +199,19 @@ for version in $mariadb_versions; do
continue
fi
url=${MARIADB_URLS[$version]}
# Get URL based on version
case "$version" in
"10.6")
url="https://archive.mariadb.org/mariadb-10.6.21/bintar-linux-systemd-x86_64/mariadb-10.6.21-linux-systemd-x86_64.tar.gz"
;;
"12.1")
url="https://archive.mariadb.org/mariadb-12.1.2/bintar-linux-systemd-x86_64/mariadb-12.1.2-linux-systemd-x86_64.tar.gz"
;;
*)
echo " Warning: Unknown MariaDB version $version"
continue
;;
esac
TEMP_DIR="/tmp/mariadb_install_$version"
mkdir -p "$TEMP_DIR"
@@ -238,43 +254,48 @@ mkdir -p "$MONGODB_DIR/bin"
echo "Installing MongoDB Database Tools to: $MONGODB_DIR"
# MongoDB Database Tools are backward compatible - single version supports all servers (4.0-8.0)
# Detect architecture
ARCH=$(uname -m)
if [ "$ARCH" = "x86_64" ]; then
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
elif [ "$ARCH" = "aarch64" ]; then
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-aarch64-100.10.0.deb"
# Skip if already installed
if [ -f "$MONGODB_DIR/bin/mongodump" ] && [ -L "$MONGODB_DIR/bin/mongodump" ]; then
echo "MongoDB Database Tools already installed, skipping..."
else
echo "Warning: Unsupported architecture $ARCH for MongoDB Database Tools"
MONGODB_TOOLS_URL=""
fi
if [ -n "$MONGODB_TOOLS_URL" ]; then
TEMP_DIR="/tmp/mongodb_install"
mkdir -p "$TEMP_DIR"
cd "$TEMP_DIR"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O mongodb-database-tools.deb || {
echo "Warning: Could not download MongoDB Database Tools"
cd - >/dev/null
rm -rf "$TEMP_DIR"
}
if [ -f "mongodb-database-tools.deb" ]; then
echo "Installing MongoDB Database Tools..."
$SUDO dpkg -i mongodb-database-tools.deb 2>/dev/null || $SUDO apt-get install -f -y -qq
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump "$MONGODB_DIR/bin/mongodump"
ln -sf /usr/bin/mongorestore "$MONGODB_DIR/bin/mongorestore"
echo "MongoDB Database Tools installed successfully"
# MongoDB Database Tools are backward compatible - single version supports all servers (4.0-8.0)
# Detect architecture
ARCH=$(uname -m)
if [ "$ARCH" = "x86_64" ]; then
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
elif [ "$ARCH" = "aarch64" ]; then
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-aarch64-100.10.0.deb"
else
echo "Warning: Unsupported architecture $ARCH for MongoDB Database Tools"
MONGODB_TOOLS_URL=""
fi
cd - >/dev/null
rm -rf "$TEMP_DIR"
if [ -n "$MONGODB_TOOLS_URL" ]; then
TEMP_DIR="/tmp/mongodb_install"
mkdir -p "$TEMP_DIR"
cd "$TEMP_DIR"
echo "Downloading MongoDB Database Tools..."
if ! wget -q "$MONGODB_TOOLS_URL" -O mongodb-database-tools.deb; then
echo "Warning: Could not download MongoDB Database Tools"
cd - >/dev/null
rm -rf "$TEMP_DIR"
else
if [ -f "mongodb-database-tools.deb" ]; then
echo "Installing MongoDB Database Tools..."
$SUDO dpkg -i mongodb-database-tools.deb 2>/dev/null || $SUDO apt-get install -f -y -qq
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump "$MONGODB_DIR/bin/mongodump"
ln -sf /usr/bin/mongorestore "$MONGODB_DIR/bin/mongorestore"
echo "MongoDB Database Tools installed successfully"
fi
cd - >/dev/null
rm -rf "$TEMP_DIR"
fi
fi
fi
echo

View File

@@ -29,23 +29,25 @@ export const backupsApi = {
return apiHelper.fetchDeleteRaw(`${getApplicationServer()}/api/v1/backups/${id}`);
},
async downloadBackup(id: string): Promise<{ blob: Blob; filename: string }> {
const result = await apiHelper.fetchGetBlobWithHeaders(
`${getApplicationServer()}/api/v1/backups/${id}/file`,
);
async downloadBackup(id: string): Promise<void> {
// Generate short-lived download token
const tokenResponse = await apiHelper.fetchPostJson<{
token: string;
filename: string;
backupId: string;
}>(`${getApplicationServer()}/api/v1/backups/${id}/download-token`, new RequestOptions());
// Extract filename from Content-Disposition header
const contentDisposition = result.headers.get('Content-Disposition');
let filename = `backup_${id}.backup`; // fallback filename
// Create direct download link with token
const downloadUrl = `${getApplicationServer()}/api/v1/backups/${id}/file?token=${tokenResponse.token}`;
if (contentDisposition) {
const filenameMatch = contentDisposition.match(/filename="?(.+?)"?$/);
if (filenameMatch && filenameMatch[1]) {
filename = filenameMatch[1];
}
}
const link = document.createElement('a');
link.href = downloadUrl;
link.download = tokenResponse.filename;
link.style.display = 'none';
return { blob: result.blob, filename };
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
},
async cancelBackup(id: string) {

View File

@@ -1,3 +1,4 @@
export interface IsReadOnlyResponse {
isReadOnly: boolean;
privileges: string[];
}

View File

@@ -9,4 +9,5 @@ export interface MariadbDatabase {
password: string;
database?: string;
isHttps: boolean;
isExcludeEvents?: boolean;
}

View File

@@ -2,4 +2,5 @@ export interface GoogleDriveStorage {
clientId: string;
clientSecret: string;
tokenJson?: string;
useLocalRedirect?: boolean;
}

View File

@@ -64,21 +64,7 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
const downloadBackup = async (backupId: string) => {
try {
const { blob, filename } = await backupsApi.downloadBackup(backupId);
// Create a download link
const url = window.URL.createObjectURL(blob);
const link = document.createElement('a');
link.href = url;
link.download = filename;
// Trigger download
document.body.appendChild(link);
link.click();
// Cleanup
document.body.removeChild(link);
window.URL.revokeObjectURL(url);
await backupsApi.downloadBackup(backupId);
} catch (e) {
alert((e as Error).message);
} finally {

View File

@@ -163,6 +163,10 @@ export const CreateDatabaseComponent = ({ workspaceId, onCreated, onClose }: Pro
}
if (step === 'notifiers') {
if (isCreating) {
return <div>Creating database...</div>;
}
return (
<EditDatabaseNotifiersComponent
database={database}

View File

@@ -121,6 +121,7 @@ export const DatabaseConfigComponent = ({
const remove = () => {
if (!database) return;
setIsShowRemoveConfirm(false);
setIsRemoving(true);
databaseApi
.deleteDatabase(database.id)
@@ -165,7 +166,18 @@ export const DatabaseConfigComponent = ({
};
return (
<div className="w-full rounded-tr-md rounded-br-md rounded-bl-md bg-white p-3 shadow sm:p-5 dark:bg-gray-800">
<div className="relative w-full rounded-tr-md rounded-br-md rounded-bl-md bg-white p-3 shadow sm:p-5 dark:bg-gray-800">
{isRemoving && (
<div className="absolute inset-0 z-10 flex items-center justify-center rounded-tr-md rounded-br-md rounded-bl-md bg-white/80 dark:bg-gray-800/80">
<div className="flex flex-col items-center gap-3">
<div className="h-8 w-8 animate-spin rounded-full border-4 border-gray-300 border-t-blue-500" />
<span className="text-sm font-medium text-gray-600 dark:text-gray-300">
Removing database...
</span>
</div>
</div>
)}
{!isEditName ? (
<div className="mb-5 flex items-center text-xl font-bold sm:text-2xl">
{database.name}

View File

@@ -11,6 +11,8 @@ interface Props {
onContinue: () => void;
}
const PRIVILEGES_TRUNCATE_LENGTH = 50;
export const CreateReadOnlyComponent = ({
database,
onReadOnlyUserUpdated,
@@ -20,6 +22,8 @@ export const CreateReadOnlyComponent = ({
const [isCheckingReadOnlyUser, setIsCheckingReadOnlyUser] = useState(false);
const [isCreatingReadOnlyUser, setIsCreatingReadOnlyUser] = useState(false);
const [isShowSkipConfirmation, setShowSkipConfirmation] = useState(false);
const [privileges, setPrivileges] = useState<string[]>([]);
const [isPrivilegesExpanded, setIsPrivilegesExpanded] = useState(false);
const isPostgres = database.type === DatabaseType.POSTGRES;
const isMysql = database.type === DatabaseType.MYSQL;
@@ -35,9 +39,12 @@ export const CreateReadOnlyComponent = ({
? 'MongoDB'
: 'database';
const privilegesLabel = isMongodb ? 'roles' : 'privileges';
const checkReadOnlyUser = async (): Promise<boolean> => {
try {
const response = await databaseApi.isUserReadOnly(database);
setPrivileges(response.privileges || []);
return response.isReadOnly;
} catch (e) {
alert((e as Error).message);
@@ -45,6 +52,20 @@ export const CreateReadOnlyComponent = ({
}
};
const getPrivilegesDisplay = () => {
const fullText = privileges.join(', ');
if (isPrivilegesExpanded || fullText.length <= PRIVILEGES_TRUNCATE_LENGTH) {
return fullText;
}
return fullText.substring(0, PRIVILEGES_TRUNCATE_LENGTH) + '...';
};
const shouldShowExpandToggle = () => {
const fullText = privileges.join(', ');
return fullText.length > PRIVILEGES_TRUNCATE_LENGTH;
};
const createReadOnlyUser = async () => {
setIsCreatingReadOnlyUser(true);
@@ -139,6 +160,31 @@ export const CreateReadOnlyComponent = ({
<b>Read-only user allows to avoid storing credentials with write access at all</b>. Even
in the worst case of hacking, nobody will be able to corrupt your data.
</p>
<p className="mt-3">
{privileges.length === 0 ? (
<>
Current user has <b>no write {privilegesLabel}</b>.
</>
) : (
<>
Current user has the following write {privilegesLabel}:{' '}
<span
className={shouldShowExpandToggle() ? 'cursor-pointer hover:opacity-80' : ''}
onClick={() =>
shouldShowExpandToggle() && setIsPrivilegesExpanded(!isPrivilegesExpanded)
}
>
{getPrivilegesDisplay()}
{shouldShowExpandToggle() && (
<span className="ml-1 text-xs text-blue-600 hover:opacity-80">
({isPrivilegesExpanded ? 'collapse' : 'expand'})
</span>
)}
</span>
</>
)}
</p>
</div>
<div className="mt-5 flex">

View File

@@ -1,5 +1,5 @@
import { CopyOutlined } from '@ant-design/icons';
import { App, Button, Input, InputNumber, Switch } from 'antd';
import { CopyOutlined, DownOutlined, InfoCircleOutlined, UpOutlined } from '@ant-design/icons';
import { App, Button, Checkbox, Input, InputNumber, Switch, Tooltip } from 'antd';
import { useEffect, useState } from 'react';
import { type Database, databaseApi } from '../../../../entity/databases';
@@ -45,6 +45,9 @@ export const EditMariaDbSpecificDataComponent = ({
const [isTestingConnection, setIsTestingConnection] = useState(false);
const [isConnectionFailed, setIsConnectionFailed] = useState(false);
const hasAdvancedValues = !!database.mariadb?.isExcludeEvents;
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const parseFromClipboard = async () => {
try {
const text = await navigator.clipboard.readText();
@@ -297,7 +300,7 @@ export const EditMariaDbSpecificDataComponent = ({
</div>
)}
<div className="mb-3 flex w-full items-center">
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Use HTTPS</div>
<Switch
checked={editingDatabase.mariadb?.isHttps}
@@ -314,6 +317,52 @@ export const EditMariaDbSpecificDataComponent = ({
/>
</div>
<div className="mt-4 mb-1 flex items-center">
<div
className="flex cursor-pointer items-center text-sm text-blue-600 hover:text-blue-800"
onClick={() => setShowAdvanced(!isShowAdvanced)}
>
<span className="mr-2">Advanced settings</span>
{isShowAdvanced ? (
<UpOutlined style={{ fontSize: '12px' }} />
) : (
<DownOutlined style={{ fontSize: '12px' }} />
)}
</div>
</div>
{isShowAdvanced && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Exclude events</div>
<div className="flex items-center">
<Checkbox
checked={editingDatabase.mariadb?.isExcludeEvents || false}
onChange={(e) => {
if (!editingDatabase.mariadb) return;
setEditingDatabase({
...editingDatabase,
mariadb: {
...editingDatabase.mariadb,
isExcludeEvents: e.target.checked,
},
});
}}
>
Skip events
</Checkbox>
<Tooltip
className="cursor-pointer"
title="Skip backing up database events. Enable this if the event scheduler is disabled on your MariaDB server."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
)}
<div className="mt-5 flex">
{isShowCancelButton && (
<Button className="mr-1" danger ghost onClick={() => onCancel()}>

View File

@@ -55,6 +55,13 @@ export const ShowMariaDbSpecificDataComponent = ({ database }: Props) => {
<div className="min-w-[150px]">Use HTTPS</div>
<div>{database.mariadb?.isHttps ? 'Yes' : 'No'}</div>
</div>
{database.mariadb?.isExcludeEvents && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Exclude events</div>
<div>Yes</div>
</div>
)}
</div>
);
};

View File

@@ -163,7 +163,7 @@ export const RestoresComponent = ({ database, backup }: Props) => {
loading={isRestoreInProgress}
onClick={() => setIsShowRestore(true)}
>
Restore from backup
Select database to restore to
</Button>
{restores.length === 0 && (

View File

@@ -1,4 +1,6 @@
import { Button, Input } from 'antd';
import { DownOutlined, InfoCircleOutlined, UpOutlined } from '@ant-design/icons';
import { Button, Checkbox, Input, Tooltip } from 'antd';
import { useState } from 'react';
import { GOOGLE_DRIVE_OAUTH_REDIRECT_URL } from '../../../../../constants';
import type { Storage } from '../../../../../entity/storages';
@@ -11,12 +13,17 @@ interface Props {
}
export function EditGoogleDriveStorageComponent({ storage, setStorage, setUnsaved }: Props) {
const hasAdvancedValues = !!storage?.googleDriveStorage?.useLocalRedirect;
const [showAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const goToAuthUrl = () => {
if (!storage?.googleDriveStorage?.clientId || !storage?.googleDriveStorage?.clientSecret) {
return;
}
const redirectUri = GOOGLE_DRIVE_OAUTH_REDIRECT_URL;
const localRedirectUri = `${window.location.origin}/storages/google-oauth`;
const useLocal = storage.googleDriveStorage.useLocalRedirect;
const redirectUri = useLocal ? localRedirectUri : GOOGLE_DRIVE_OAUTH_REDIRECT_URL;
const clientId = storage.googleDriveStorage.clientId;
const scope = 'https://www.googleapis.com/auth/drive.file';
const originUrl = `${window.location.origin}/storages/google-oauth`;
@@ -92,6 +99,53 @@ export function EditGoogleDriveStorageComponent({ storage, setStorage, setUnsave
/>
</div>
<div className="mt-4 mb-3 flex items-center">
<div
className="flex cursor-pointer items-center text-sm text-blue-600 hover:text-blue-800"
onClick={() => setShowAdvanced(!showAdvanced)}
>
<span className="mr-2">Advanced settings</span>
{showAdvanced ? (
<UpOutlined style={{ fontSize: '12px' }} />
) : (
<DownOutlined style={{ fontSize: '12px' }} />
)}
</div>
</div>
{showAdvanced && (
<div className="mb-4 flex w-full flex-col items-start sm:flex-row sm:items-center">
<div className="flex items-center">
<Checkbox
checked={storage?.googleDriveStorage?.useLocalRedirect || false}
onChange={(e) => {
if (!storage?.googleDriveStorage) return;
setStorage({
...storage,
googleDriveStorage: {
...storage.googleDriveStorage,
useLocalRedirect: e.target.checked,
},
});
setUnsaved();
}}
disabled={!!storage?.googleDriveStorage?.tokenJson}
>
<span>Use local redirect</span>
</Checkbox>
<Tooltip
className="cursor-pointer"
title="When enabled, uses your address as the origin and redirect URL (specify it in Google Cloud Console). HTTPS is required."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
)}
{storage?.googleDriveStorage?.tokenJson && (
<>
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">

View File

@@ -18,6 +18,8 @@ export function OauthStorageComponent() {
const { clientId, clientSecret } = oauthDto.storage.googleDriveStorage;
const { authCode } = oauthDto;
const redirectUri = oauthDto.redirectUrl || GOOGLE_DRIVE_OAUTH_REDIRECT_URL;
try {
// Exchange authorization code for access token
const response = await fetch('https://oauth2.googleapis.com/token', {
@@ -29,13 +31,16 @@ export function OauthStorageComponent() {
code: authCode,
client_id: clientId,
client_secret: clientSecret,
redirect_uri: GOOGLE_DRIVE_OAUTH_REDIRECT_URL,
redirect_uri: redirectUri,
grant_type: 'authorization_code',
}),
});
if (!response.ok) {
throw new Error(`OAuth exchange failed: ${response.statusText}`);
const errorData = await response.json();
throw new Error(
errorData.error_description || `OAuth exchange failed: ${response.statusText}`,
);
}
const tokenData = await response.json();
@@ -44,27 +49,71 @@ export function OauthStorageComponent() {
setStorage(oauthDto.storage);
} catch (error) {
alert(`Failed to exchange OAuth code: ${error}`);
// Return to home if exchange fails
setTimeout(() => {
window.location.href = '/';
}, 3000);
}
};
useEffect(() => {
const oauthDtoParam = new URLSearchParams(window.location.search).get('oauthDto');
if (!oauthDtoParam) {
alert('OAuth param not found');
return;
}
const decodedParam = decodeURIComponent(oauthDtoParam);
const oauthDto: StorageOauthDto = JSON.parse(decodedParam);
/**
* Helper to validate the DTO and start the exchange process
*/
const processOauthDto = (oauthDto: StorageOauthDto) => {
if (oauthDto.storage.type === StorageType.GOOGLE_DRIVE) {
if (!oauthDto.storage.googleDriveStorage) {
alert('Google Drive storage not found');
alert('Google Drive storage configuration not found in DTO');
return;
}
exchangeGoogleOauthCode(oauthDto);
} else {
alert('Unsupported storage type for OAuth');
}
};
useEffect(() => {
const urlParams = new URLSearchParams(window.location.search);
// Attempt 1: Check for the 'oauthDto' param (Third-party/Legacy way)
const oauthDtoParam = urlParams.get('oauthDto');
if (oauthDtoParam) {
try {
const decodedParam = decodeURIComponent(oauthDtoParam);
const oauthDto: StorageOauthDto = JSON.parse(decodedParam);
processOauthDto(oauthDto);
return;
} catch (e) {
console.error('Error parsing oauthDto parameter:', e);
alert('Malformed OAuth parameter received');
return;
}
}
// Attempt 2: Check for 'code' and 'state' (Direct Google/Local way)
const code = urlParams.get('code');
const state = urlParams.get('state');
if (code && state) {
try {
// The 'state' parameter contains our stringified StorageOauthDto
const decodedState = decodeURIComponent(state);
const oauthDto: StorageOauthDto = JSON.parse(decodedState);
// Inject the authorization code received from Google
oauthDto.authCode = code;
processOauthDto(oauthDto);
return;
} catch (e) {
console.error('Error parsing OAuth state:', e);
alert('OAuth state parameter is invalid');
return;
}
}
// Attempt 3: No valid parameters found
alert('OAuth param not found. Ensure the redirect URL is configured correctly.');
}, []);
if (!storage) {