mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
17 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e4bc34d319 | ||
|
|
257ae85da7 | ||
|
|
b42c820bb2 | ||
|
|
da5c13fb11 | ||
|
|
35180360e5 | ||
|
|
e4f6cd7a5d | ||
|
|
d7b8e6d56a | ||
|
|
6016f23fb2 | ||
|
|
e7c4ee8f6f | ||
|
|
a75702a01b | ||
|
|
81a21eb907 | ||
|
|
33d6bf0147 | ||
|
|
6eb53bb07b | ||
|
|
6ac04270b9 | ||
|
|
b0510d7c21 | ||
|
|
dc5f271882 | ||
|
|
8f718771c9 |
256
AGENTS.md
256
AGENTS.md
@@ -1,35 +1,37 @@
|
||||
# Agent Rules and Guidelines
|
||||
|
||||
This document contains all coding standards, conventions and best practices recommended for the Databasus project.
|
||||
This document contains all coding standards, conventions and best practices recommended for the TgTaps project.
|
||||
This is NOT a strict set of rules, but a set of recommendations to help you write better code.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Engineering Philosophy](#engineering-philosophy)
|
||||
- [Backend Guidelines](#backend-guidelines)
|
||||
- [Code Style](#code-style)
|
||||
- [Engineering philosophy](#engineering-philosophy)
|
||||
- [Backend guidelines](#backend-guidelines)
|
||||
- [Code style](#code-style)
|
||||
- [Boolean naming](#boolean-naming)
|
||||
- [Add reasonable new lines between logical statements](#add-reasonable-new-lines-between-logical-statements)
|
||||
- [Comments](#comments)
|
||||
- [Controllers](#controllers)
|
||||
- [Dependency Injection (DI)](#dependency-injection-di)
|
||||
- [Dependency injection (DI)](#dependency-injection-di)
|
||||
- [Migrations](#migrations)
|
||||
- [Refactoring](#refactoring)
|
||||
- [Testing](#testing)
|
||||
- [Time Handling](#time-handling)
|
||||
- [CRUD Examples](#crud-examples)
|
||||
- [Frontend Guidelines](#frontend-guidelines)
|
||||
- [React Component Structure](#react-component-structure)
|
||||
- [Time handling](#time-handling)
|
||||
- [CRUD examples](#crud-examples)
|
||||
- [Frontend guidelines](#frontend-guidelines)
|
||||
- [React component structure](#react-component-structure)
|
||||
|
||||
---
|
||||
|
||||
## Engineering Philosophy
|
||||
## Engineering philosophy
|
||||
|
||||
**Think like a skeptical senior engineer and code reviewer. Don't just do what was asked—also think about what should have been asked.**
|
||||
|
||||
⚠️ **Balance vigilance with pragmatism:** Catch real issues, not theoretical ones. Don't let perfect be the enemy of good.
|
||||
|
||||
### Task Context Assessment:
|
||||
### Task context assessment:
|
||||
|
||||
**First, assess the task scope:**
|
||||
|
||||
@@ -38,7 +40,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
- **Complex** (architecture, security, performance-critical): Full analysis required
|
||||
- **Unclear** (ambiguous requirements): Always clarify assumptions first
|
||||
|
||||
### For Non-Trivial Tasks:
|
||||
### For non-trivial tasks:
|
||||
|
||||
1. **Restate the objective and list assumptions** (explicit + implicit)
|
||||
- If any assumption is shaky, call it out clearly
|
||||
@@ -71,7 +73,7 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
- Patch the answer accordingly
|
||||
- Verify edge cases are handled
|
||||
|
||||
### Application Guidelines:
|
||||
### Application guidelines:
|
||||
|
||||
**Scale your response to the task:**
|
||||
|
||||
@@ -84,9 +86,9 @@ This is NOT a strict set of rules, but a set of recommendations to help you writ
|
||||
|
||||
---
|
||||
|
||||
## Backend Guidelines
|
||||
## Backend guidelines
|
||||
|
||||
### Code Style
|
||||
### Code style
|
||||
|
||||
**Always place private methods to the bottom of file**
|
||||
|
||||
@@ -94,7 +96,7 @@ This rule applies to ALL Go files including tests, services, controllers, reposi
|
||||
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
#### Structure Order:
|
||||
#### Structure order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
@@ -227,7 +229,7 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
}
|
||||
```
|
||||
|
||||
#### Key Points:
|
||||
#### Key points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
@@ -237,13 +239,13 @@ func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
|
||||
---
|
||||
|
||||
### Boolean Naming
|
||||
### Boolean naming
|
||||
|
||||
**Always prefix boolean variables with verbs like `is`, `has`, `was`, `should`, `can`, etc.**
|
||||
|
||||
This makes the code more readable and clearly indicates that the variable represents a true/false state.
|
||||
|
||||
#### Good Examples:
|
||||
#### Good examples:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
@@ -265,7 +267,7 @@ wasCompleted := false
|
||||
hasPermission := checkPermissions()
|
||||
```
|
||||
|
||||
#### Bad Examples:
|
||||
#### Bad examples:
|
||||
|
||||
```go
|
||||
type User struct {
|
||||
@@ -286,7 +288,7 @@ completed := false // Should be: wasCompleted
|
||||
permission := true // Should be: hasPermission
|
||||
```
|
||||
|
||||
#### Common Boolean Prefixes:
|
||||
#### Common boolean prefixes:
|
||||
|
||||
- **is** - current state (IsActive, IsValid, IsEnabled)
|
||||
- **has** - possession or presence (HasAccess, HasPermission, HasError)
|
||||
@@ -297,6 +299,167 @@ permission := true // Should be: hasPermission
|
||||
|
||||
---
|
||||
|
||||
### Add reasonable new lines between logical statements
|
||||
|
||||
**Add blank lines between logical blocks to improve code readability.**
|
||||
|
||||
Separate different logical operations within a function with blank lines. This makes the code flow clearer and helps identify distinct steps in the logic.
|
||||
|
||||
#### Guidelines:
|
||||
|
||||
- Add blank line before final `return` statement
|
||||
- Add blank line after variable declarations before using them
|
||||
- Add blank line between error handling and subsequent logic
|
||||
- Add blank line between different logical operations
|
||||
|
||||
#### Bad example (without spacing):
|
||||
|
||||
```go
|
||||
func (t *Task) BeforeSave(tx *gorm.DB) error {
|
||||
if len(t.Messages) > 0 {
|
||||
messagesBytes, err := json.Marshal(t.Messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.MessagesJSON = string(messagesBytes)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Task) AfterFind(tx *gorm.DB) error {
|
||||
if t.MessagesJSON != "" {
|
||||
var messages []onewin_dto.TaskCompletionMessage
|
||||
if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
|
||||
return err
|
||||
}
|
||||
t.Messages = messages
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
#### Good example (with proper spacing):
|
||||
|
||||
```go
|
||||
func (t *Task) BeforeSave(tx *gorm.DB) error {
|
||||
if len(t.Messages) > 0 {
|
||||
messagesBytes, err := json.Marshal(t.Messages)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.MessagesJSON = string(messagesBytes)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Task) AfterFind(tx *gorm.DB) error {
|
||||
if t.MessagesJSON != "" {
|
||||
var messages []onewin_dto.TaskCompletionMessage
|
||||
if err := json.Unmarshal([]byte(t.MessagesJSON), &messages); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.Messages = messages
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
#### More examples:
|
||||
|
||||
**Service method with multiple operations:**
|
||||
|
||||
```go
|
||||
func (s *UserService) CreateUser(request *CreateUserRequest) (*User, error) {
|
||||
// Validate input
|
||||
if err := s.validateUserRequest(request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create user entity
|
||||
user := &User{
|
||||
ID: uuid.New(),
|
||||
Name: request.Name,
|
||||
Email: request.Email,
|
||||
}
|
||||
|
||||
// Save to database
|
||||
if err := s.repository.Create(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send notification
|
||||
s.notificationService.SendWelcomeEmail(user.Email)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Repository method with query building:**
|
||||
|
||||
```go
|
||||
func (r *Repository) GetFiltered(filters *Filters) ([]*Entity, error) {
|
||||
query := storage.GetDb().Model(&Entity{})
|
||||
|
||||
if filters.Status != "" {
|
||||
query = query.Where("status = ?", filters.Status)
|
||||
}
|
||||
|
||||
if filters.CreatedAfter != nil {
|
||||
query = query.Where("created_at > ?", filters.CreatedAfter)
|
||||
}
|
||||
|
||||
var entities []*Entity
|
||||
if err := query.Find(&entities).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return entities, nil
|
||||
}
|
||||
```
|
||||
|
||||
**Repository method with error handling:**
|
||||
|
||||
Bad (without spacing):
|
||||
|
||||
```go
|
||||
func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
|
||||
var task models.Task
|
||||
result := storage.GetDb().Where("id = ?", id).First(&task)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
}
|
||||
return nil, result.Error
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
```
|
||||
|
||||
Good (with proper spacing):
|
||||
|
||||
```go
|
||||
func (r *Repository) FindById(id uuid.UUID) (*models.Task, error) {
|
||||
var task models.Task
|
||||
|
||||
result := storage.GetDb().Where("id = ?", id).First(&task)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("task not found")
|
||||
}
|
||||
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return &task, nil
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Comments
|
||||
|
||||
#### Guidelines
|
||||
@@ -305,13 +468,14 @@ permission := true // Should be: hasPermission
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
#### Key Principles:
|
||||
#### Key principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
- **Do not write summary sections in .md files unless directly requested** - Avoid adding "Summary" or "Conclusion" sections at the end of documentation files unless the user explicitly asks for them
|
||||
|
||||
#### Example of useless comments:
|
||||
|
||||
@@ -343,7 +507,7 @@ func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemReq
|
||||
|
||||
### Controllers
|
||||
|
||||
#### Controller Guidelines:
|
||||
#### Controller guidelines:
|
||||
|
||||
1. **When we write controller:**
|
||||
- We combine all routes to single controller
|
||||
@@ -475,7 +639,7 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
|
||||
---
|
||||
|
||||
### Dependency Injection (DI)
|
||||
### Dependency injection (DI)
|
||||
|
||||
For DI files use **implicit fields declaration styles** (especially for controllers, services, repositories, use cases, etc., not simple data structures).
|
||||
|
||||
@@ -503,7 +667,7 @@ var orderController = &OrderController{
|
||||
|
||||
**This is needed to avoid forgetting to update DI style when we add new dependency.**
|
||||
|
||||
#### Force Such Usage
|
||||
#### Force such usage
|
||||
|
||||
Please force such usage if file look like this (see some services\controllers\repos definitions and getters):
|
||||
|
||||
@@ -549,13 +713,13 @@ func GetOrderRepository() *repositories.OrderRepository {
|
||||
}
|
||||
```
|
||||
|
||||
#### SetupDependencies() Pattern
|
||||
#### SetupDependencies() pattern
|
||||
|
||||
**All `SetupDependencies()` functions must use sync.Once to ensure idempotent execution.**
|
||||
|
||||
This pattern allows `SetupDependencies()` to be safely called multiple times (especially in tests) while ensuring the actual setup logic executes only once.
|
||||
|
||||
**Implementation Pattern:**
|
||||
**Implementation pattern:**
|
||||
|
||||
```go
|
||||
package feature
|
||||
@@ -588,7 +752,7 @@ func SetupDependencies() {
|
||||
}
|
||||
```
|
||||
|
||||
**Why This Pattern:**
|
||||
**Why this pattern:**
|
||||
|
||||
- **Tests can call multiple times**: Test setup often calls `SetupDependencies()` multiple times without issues
|
||||
- **Thread-safe**: Works correctly with concurrent calls (nanoseconds or seconds apart)
|
||||
@@ -604,13 +768,13 @@ func SetupDependencies() {
|
||||
|
||||
---
|
||||
|
||||
### Background Services
|
||||
### Background services
|
||||
|
||||
**All background service `Run()` methods must panic if called multiple times to prevent corrupted states.**
|
||||
|
||||
Background services run infinite loops and must never be started twice on the same instance. Multiple calls indicate a serious bug that would cause duplicate goroutines, resource leaks, and data corruption.
|
||||
|
||||
**Implementation Pattern:**
|
||||
**Implementation pattern:**
|
||||
|
||||
```go
|
||||
package feature
|
||||
@@ -654,14 +818,14 @@ func (s *BackgroundService) Run(ctx context.Context) {
|
||||
}
|
||||
```
|
||||
|
||||
**Why Panic Instead of Warning:**
|
||||
**Why panic instead of warning:**
|
||||
|
||||
- **Prevents corruption**: Multiple `Run()` calls would create duplicate goroutines consuming resources
|
||||
- **Fails fast**: Catches critical bugs immediately in tests and production
|
||||
- **Clear indication**: Panic clearly indicates a serious programming error
|
||||
- **Applies everywhere**: Same protection in tests and production
|
||||
|
||||
**When This Applies:**
|
||||
**When this applies:**
|
||||
|
||||
- All background services with infinite loops
|
||||
- Registry services (BackupNodesRegistry, RestoreNodesRegistry)
|
||||
@@ -727,14 +891,14 @@ You can shortify, make more readable, improve code quality, etc. Common logic ca
|
||||
|
||||
**After writing tests, always launch them and verify that they pass.**
|
||||
|
||||
#### Test Naming Format
|
||||
#### Test naming format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
#### Examples from Real Codebase:
|
||||
#### Examples from real codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
@@ -742,22 +906,22 @@ Use these naming patterns:
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
#### Testing Philosophy
|
||||
#### Testing philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
**Prefer controllers over unit tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
**Extract common logic to testing utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
**Refactor existing tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
@@ -766,7 +930,7 @@ Use these naming patterns:
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
**Clean Up Test Data:**
|
||||
**Clean up test data:**
|
||||
|
||||
- If the feature supports cleanup operations (DELETE endpoints, cleanup methods), use them in tests
|
||||
- Clean up resources after test execution to avoid test data pollution
|
||||
@@ -803,7 +967,7 @@ func Test_BackupLifecycle_CreateAndDelete(t *testing.T) {
|
||||
}
|
||||
```
|
||||
|
||||
#### Testing Utilities Structure
|
||||
#### Testing utilities structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
@@ -839,7 +1003,7 @@ func AddMemberToProject(project *projects_models.Project, member *users_dto.Sign
|
||||
}
|
||||
```
|
||||
|
||||
#### Controller Test Examples
|
||||
#### Controller test examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
@@ -906,7 +1070,7 @@ func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
|
||||
---
|
||||
|
||||
### Time Handling
|
||||
### Time handling
|
||||
|
||||
**Always use `time.Now().UTC()` instead of `time.Now()`**
|
||||
|
||||
@@ -914,7 +1078,7 @@ This ensures consistent timezone handling across the application.
|
||||
|
||||
---
|
||||
|
||||
### CRUD Examples
|
||||
### CRUD examples
|
||||
|
||||
This is an example of complete CRUD implementation structure:
|
||||
|
||||
@@ -1578,9 +1742,9 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti
|
||||
|
||||
---
|
||||
|
||||
## Frontend Guidelines
|
||||
## Frontend guidelines
|
||||
|
||||
### React Component Structure
|
||||
### React component structure
|
||||
|
||||
Write React components with the following structure:
|
||||
|
||||
@@ -1614,7 +1778,7 @@ export const ReactComponent = ({ someValue }: Props): JSX.Element => {
|
||||
}
|
||||
```
|
||||
|
||||
#### Structure Order:
|
||||
#### Structure order:
|
||||
|
||||
1. **Props interface** - Define component props
|
||||
2. **Helper functions** (outside component) - Pure utility functions
|
||||
|
||||
@@ -431,6 +431,8 @@ fi
|
||||
exec ./main
|
||||
EOF
|
||||
|
||||
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
|
||||
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
EXPOSE 4005
|
||||
@@ -439,4 +441,4 @@ EXPOSE 4005
|
||||
VOLUME ["/databasus-data"]
|
||||
|
||||
ENTRYPOINT ["/app/start.sh"]
|
||||
CMD []
|
||||
CMD []
|
||||
|
||||
@@ -103,6 +103,16 @@ func (s *BackupsScheduler) IsSchedulerRunning() bool {
|
||||
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
|
||||
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get available nodes for health check", "error", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return len(nodes) > 0
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
|
||||
@@ -1366,11 +1366,24 @@ func createTestBackup(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(loadedStorages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
// Filter out system storages
|
||||
var nonSystemStorages []*storages.Storage
|
||||
for _, storage := range loadedStorages {
|
||||
if !storage.IsSystem {
|
||||
nonSystemStorages = append(nonSystemStorages, storage)
|
||||
}
|
||||
}
|
||||
if len(nonSystemStorages) == 0 {
|
||||
panic("No non-system storage found for workspace")
|
||||
}
|
||||
|
||||
storages := nonSystemStorages
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
|
||||
@@ -108,13 +108,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if mdb.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if mdb.HasPrivilege("EVENT") {
|
||||
|
||||
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--routines",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
|
||||
@@ -1164,12 +1164,13 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,13 +25,14 @@ type MariadbDatabase struct {
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsExcludeEvents = incoming.IsExcludeEvents
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
|
||||
@@ -26,12 +26,13 @@ type MongodbDatabase struct {
|
||||
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Port *int `json:"port" gorm:"type:int"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database string `json:"database" gorm:"type:text;not null"`
|
||||
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
|
||||
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
|
||||
}
|
||||
|
||||
@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return errors.New("port is required")
|
||||
|
||||
if !m.IsSrv {
|
||||
if m.Port == nil || *m.Port == 0 {
|
||||
return errors.New("port is required for standard connections")
|
||||
}
|
||||
}
|
||||
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.CpuCount <= 0 {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
|
||||
m.Database = incoming.Database
|
||||
m.AuthDatabase = incoming.AuthDatabase
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsSrv = incoming.IsSrv
|
||||
m.CpuCount = incoming.CpuCount
|
||||
|
||||
if incoming.Password != "" {
|
||||
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
|
||||
@@ -64,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -133,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -442,15 +446,17 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
|
||||
port := container.Port
|
||||
return &MongodbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
@@ -489,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
|
||||
strings.Contains(errStr, "permission denied"),
|
||||
"Expected authorization error, got: %v", err)
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "port is required for standard connections")
|
||||
}
|
||||
|
||||
@@ -564,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 4: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
// Step 4: Discover schemas to grant privileges on
|
||||
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
|
||||
var rows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
AND schema_name = ANY($1::text[])
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get schemas: %w", err)
|
||||
}
|
||||
@@ -619,50 +630,197 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
|
||||
// Step 6: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, grantSelectSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 7: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// First, set default privileges for objects created by the current user
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 8: Verify user creation before committing
|
||||
// Step 8: Discover all roles that own objects in each schema
|
||||
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
|
||||
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
|
||||
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
|
||||
// Filter by IncludeSchemas if specified.
|
||||
type SchemaOwner struct {
|
||||
SchemaName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
var ownerRows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname = ANY($1::text[])
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log warning but continue - this is a best-effort enhancement
|
||||
logger.Warn("Failed to query object owners for default privileges", "error", err)
|
||||
} else {
|
||||
var schemaOwners []SchemaOwner
|
||||
for ownerRows.Next() {
|
||||
var so SchemaOwner
|
||||
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
|
||||
ownerRows.Close()
|
||||
logger.Warn("Failed to scan schema owner", "error", err)
|
||||
break
|
||||
}
|
||||
schemaOwners = append(schemaOwners, so)
|
||||
}
|
||||
ownerRows.Close()
|
||||
|
||||
if err := ownerRows.Err(); err != nil {
|
||||
logger.Warn("Error iterating schema owners", "error", err)
|
||||
}
|
||||
|
||||
// Step 9: Set default privileges FOR ROLE for each object owner
|
||||
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
|
||||
// We log warnings but continue - user creation should succeed even if some roles can't be configured
|
||||
for _, so := range schemaOwners {
|
||||
// Try to set default privileges for tables
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (tables)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
|
||||
// Try to set default privileges for sequences
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (sequences)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(schemaOwners) > 0 {
|
||||
logger.Info(
|
||||
"Set default privileges for existing object owners",
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
"owner_count",
|
||||
len(schemaOwners),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 10: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
|
||||
@@ -1319,6 +1319,346 @@ type PostgresContainer struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create a second database user who will create tables
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Step 2: Grant the user_creator privileges to connect and create tables
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CREATE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 2b: Create an initial table by user_creator so they become an object owner
|
||||
// This is important because our fix discovers existing object owners
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
initialTableName := fmt.Sprintf(
|
||||
"public.initial_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('initial_data');
|
||||
`, initialTableName, initialTableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
|
||||
}()
|
||||
|
||||
// Step 3: NOW create read-only user via Databasus (as admin)
|
||||
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
|
||||
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
|
||||
tableName := fmt.Sprintf(
|
||||
"public.future_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
|
||||
`, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
// Step 5: Connect as read-only user and verify it can SELECT from the new table
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
2,
|
||||
count,
|
||||
"Read-only user should be able to SELECT from table created by different user",
|
||||
)
|
||||
|
||||
// Step 6: Verify read-only user cannot write to the table
|
||||
_, err = readonlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify pg_dump operations (LOCK TABLE) work
|
||||
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
|
||||
tx, err := readonlyConn.Begin()
|
||||
assert.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
|
||||
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
|
||||
|
||||
err = tx.Commit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create multiple schemas and tables
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS included_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
|
||||
CREATE SCHEMA included_schema;
|
||||
CREATE SCHEMA excluded_schema;
|
||||
|
||||
CREATE TABLE public.public_table (id INT, data TEXT);
|
||||
INSERT INTO public.public_table VALUES (1, 'public_data');
|
||||
|
||||
CREATE TABLE included_schema.included_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
|
||||
}()
|
||||
|
||||
// Step 2: Create a second user who owns tables in both included and excluded schemas
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Grant privileges to user_creator
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
|
||||
schema,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// User_creator creates tables in included and excluded schemas
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.user_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
|
||||
pgModel := createPostgresModel(container)
|
||||
pgModel.IncludeSchemas = []string{"public", "included_schema"}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: Connect as read-only user
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
// Step 5: Verify read-only user CAN access included schemas
|
||||
var publicData string
|
||||
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "public_data", publicData)
|
||||
|
||||
var includedData string
|
||||
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "included_data", includedData)
|
||||
|
||||
var userIncludedData string
|
||||
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "user_included_data", userIncludedData)
|
||||
|
||||
// Step 6: Verify read-only user CANNOT access excluded schema
|
||||
var excludedData string
|
||||
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
err = readonlyConn.Get(
|
||||
&excludedData,
|
||||
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify future tables in included schemas are accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.future_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureData string
|
||||
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"future_data",
|
||||
futureData,
|
||||
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
|
||||
)
|
||||
|
||||
// Step 8: Verify future tables in excluded schema are NOT accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureExcludedData string
|
||||
err = readonlyConn.Get(
|
||||
&futureExcludedData,
|
||||
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(
|
||||
t,
|
||||
err.Error(),
|
||||
"permission denied",
|
||||
"Read-only user should NOT access tables in excluded schemas",
|
||||
)
|
||||
}
|
||||
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
|
||||
@@ -71,12 +71,13 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_dto "databasus-backend/internal/features/users/dto"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
@@ -358,7 +357,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup = createTestBackup(mysqlDB, owner)
|
||||
backup = createTestBackup(mysqlDB, storage)
|
||||
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysql.MysqlDatabase{
|
||||
@@ -610,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
backup := createTestBackup(database, storage)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
@@ -727,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
storage *storages.Storage,
|
||||
) *backups_core.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
@@ -759,7 +748,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(
|
||||
if err := storage.SaveFile(
|
||||
context.Background(),
|
||||
fieldEncryptor,
|
||||
logger,
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_repositories "databasus-backend/internal/features/workspaces/repositories"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
@@ -1969,6 +1970,143 @@ func Test_TransferSystemStorage_TransferBlocked(t *testing.T) {
|
||||
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
|
||||
}
|
||||
|
||||
func Test_DeleteWorkspace_SystemStoragesFromAnotherWorkspaceNotRemovedAndWorkspaceDeletedSuccessfully(
|
||||
t *testing.T,
|
||||
) {
|
||||
router := createRouter()
|
||||
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
|
||||
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", admin, router)
|
||||
workspaceD := workspaces_testing.CreateTestWorkspace("Workspace D", admin, router)
|
||||
|
||||
// Create a system storage in workspace A
|
||||
systemStorage := &Storage{
|
||||
WorkspaceID: workspaceA.ID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Test System Storage " + uuid.New().String(),
|
||||
IsSystem: true,
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
var savedSystemStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+admin.Token,
|
||||
*systemStorage,
|
||||
http.StatusOK,
|
||||
&savedSystemStorage,
|
||||
)
|
||||
assert.True(t, savedSystemStorage.IsSystem)
|
||||
assert.Equal(t, workspaceA.ID, savedSystemStorage.WorkspaceID)
|
||||
|
||||
// Create a regular storage in workspace D
|
||||
regularStorage := createNewStorage(workspaceD.ID)
|
||||
var savedRegularStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+admin.Token,
|
||||
*regularStorage,
|
||||
http.StatusOK,
|
||||
&savedRegularStorage,
|
||||
)
|
||||
assert.False(t, savedRegularStorage.IsSystem)
|
||||
assert.Equal(t, workspaceD.ID, savedRegularStorage.WorkspaceID)
|
||||
|
||||
// Delete workspace D
|
||||
workspaces_testing.DeleteWorkspace(workspaceD, admin.Token, router)
|
||||
|
||||
// Verify system storage from workspace A still exists
|
||||
repository := &StorageRepository{}
|
||||
systemStorageAfterDeletion, err := repository.FindByID(savedSystemStorage.ID)
|
||||
assert.NoError(t, err, "System storage should still exist after workspace D deletion")
|
||||
assert.NotNil(t, systemStorageAfterDeletion)
|
||||
assert.Equal(t, savedSystemStorage.ID, systemStorageAfterDeletion.ID)
|
||||
assert.True(t, systemStorageAfterDeletion.IsSystem)
|
||||
assert.Equal(t, workspaceA.ID, systemStorageAfterDeletion.WorkspaceID)
|
||||
|
||||
// Verify regular storage from workspace D was deleted
|
||||
regularStorageAfterDeletion, err := repository.FindByID(savedRegularStorage.ID)
|
||||
assert.Error(t, err, "Regular storage should be deleted with workspace D")
|
||||
assert.Nil(t, regularStorageAfterDeletion)
|
||||
|
||||
// Cleanup
|
||||
deleteStorage(t, router, savedSystemStorage.ID, admin.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
|
||||
}
|
||||
|
||||
func Test_DeleteWorkspace_WithOwnSystemStorage_ReturnsForbidden(t *testing.T) {
|
||||
router := createRouter()
|
||||
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
|
||||
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", admin, router)
|
||||
|
||||
// Create a system storage assigned to workspace A
|
||||
systemStorage := &Storage{
|
||||
WorkspaceID: workspaceA.ID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "System Storage in A " + uuid.New().String(),
|
||||
IsSystem: true,
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
var savedSystemStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+admin.Token,
|
||||
*systemStorage,
|
||||
http.StatusOK,
|
||||
&savedSystemStorage,
|
||||
)
|
||||
assert.True(t, savedSystemStorage.IsSystem)
|
||||
assert.Equal(t, workspaceA.ID, savedSystemStorage.WorkspaceID)
|
||||
|
||||
// Attempt to delete workspace A - should fail because it has a system storage
|
||||
resp := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"DELETE",
|
||||
"/api/v1/workspaces/"+workspaceA.ID.String(),
|
||||
"Bearer "+admin.Token,
|
||||
nil,
|
||||
)
|
||||
assert.Equal(t, http.StatusBadRequest, resp.Code, "Workspace deletion should fail")
|
||||
assert.Contains(
|
||||
t,
|
||||
resp.Body.String(),
|
||||
"system storage cannot be deleted due to workspace deletion",
|
||||
"Error message should indicate system storage prevents deletion",
|
||||
)
|
||||
|
||||
// Verify workspace still exists
|
||||
workspaceRepo := &workspaces_repositories.WorkspaceRepository{}
|
||||
workspaceAfterFailedDeletion, err := workspaceRepo.GetWorkspaceByID(workspaceA.ID)
|
||||
assert.NoError(t, err, "Workspace should still exist after failed deletion")
|
||||
assert.NotNil(t, workspaceAfterFailedDeletion)
|
||||
assert.Equal(t, workspaceA.ID, workspaceAfterFailedDeletion.ID)
|
||||
|
||||
// Verify system storage still exists
|
||||
repository := &StorageRepository{}
|
||||
storageAfterFailedDeletion, err := repository.FindByID(savedSystemStorage.ID)
|
||||
assert.NoError(t, err, "System storage should still exist after failed deletion")
|
||||
assert.NotNil(t, storageAfterFailedDeletion)
|
||||
assert.Equal(t, savedSystemStorage.ID, storageAfterFailedDeletion.ID)
|
||||
assert.True(t, storageAfterFailedDeletion.IsSystem)
|
||||
|
||||
// Cleanup: Delete system storage first, then workspace can be deleted
|
||||
deleteStorage(t, router, savedSystemStorage.ID, admin.Token)
|
||||
workspaces_testing.DeleteWorkspace(workspaceA, admin.Token, router)
|
||||
|
||||
// Verify workspace was successfully deleted after storage removal
|
||||
_, err = workspaceRepo.GetWorkspaceByID(workspaceA.ID)
|
||||
assert.Error(t, err, "Workspace should be deleted after storage was removed")
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
@@ -1983,6 +2121,7 @@ func createRouter() *gin.Engine {
|
||||
}
|
||||
|
||||
audit_logs.SetupDependencies()
|
||||
SetupDependencies()
|
||||
GetStorageService().SetStorageDatabaseCounter(&mockStorageDatabaseCounter{})
|
||||
|
||||
return router
|
||||
|
||||
@@ -25,6 +25,32 @@ func (s *StorageService) SetStorageDatabaseCounter(storageDatabaseCounter Storag
|
||||
s.storageDatabaseCounter = storageDatabaseCounter
|
||||
}
|
||||
|
||||
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get storages for workspace deletion: %w", err)
|
||||
}
|
||||
|
||||
for _, storage := range storages {
|
||||
if storage.IsSystem && storage.WorkspaceID != workspaceID {
|
||||
// skip system storage from another workspace
|
||||
continue
|
||||
}
|
||||
|
||||
if storage.IsSystem && storage.WorkspaceID == workspaceID {
|
||||
return fmt.Errorf(
|
||||
"system storage cannot be deleted due to workspace deletion, please transfer or remove storage first",
|
||||
)
|
||||
}
|
||||
|
||||
if err := s.storageRepository.Delete(storage); err != nil {
|
||||
return fmt.Errorf("failed to delete storage %s: %w", storage.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageService) SaveStorage(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
@@ -351,18 +377,3 @@ func (s *StorageService) TransferStorageToWorkspace(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get storages for workspace deletion: %w", err)
|
||||
}
|
||||
|
||||
for _, storage := range storages {
|
||||
if err := s.storageRepository.Delete(storage); err != nil {
|
||||
return fmt.Errorf("failed to delete storage %s: %w", storage.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -52,6 +52,10 @@ func (s *HealthcheckService) performHealthCheck() error {
|
||||
if !s.backupBackgroundService.IsSchedulerRunning() {
|
||||
return errors.New("backups are not running for more than 5 minutes")
|
||||
}
|
||||
|
||||
if !s.backupBackgroundService.IsBackupNodesAvailable() {
|
||||
return errors.New("no backup nodes available")
|
||||
}
|
||||
}
|
||||
|
||||
if config.GetEnv().IsProcessingNode {
|
||||
|
||||
@@ -147,6 +147,26 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMariadb_WithExcludeEvents_EventsNotRestored(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMariadbBackupRestoreWithExcludeEventsForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testMariadbBackupRestoreForVersion(
|
||||
t *testing.T,
|
||||
mariadbVersion tools.MariadbVersion,
|
||||
@@ -702,3 +722,145 @@ func updateMariadbDatabaseCredentialsViaAPI(
|
||||
|
||||
return &updatedDatabase
|
||||
}
|
||||
|
||||
func testMariadbBackupRestoreWithExcludeEventsForVersion(
|
||||
t *testing.T,
|
||||
mariadbVersion tools.MariadbVersion,
|
||||
port string,
|
||||
) {
|
||||
container, err := connectToMariadbContainer(mariadbVersion, port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMariadbTestData(t, container.DB)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE EVENT IF NOT EXISTS test_event
|
||||
ON SCHEDULE EVERY 1 DAY
|
||||
DO BEGIN
|
||||
INSERT INTO test_data (name, value) VALUES ('event_test', 999);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Skipf(
|
||||
"Skipping test: MariaDB version doesn't support events or event scheduler disabled: %v",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"MariaDB Exclude Events Test Workspace",
|
||||
user,
|
||||
router,
|
||||
)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMariadbDatabaseViaAPI(
|
||||
t, router, "MariaDB Exclude Events Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
database.Mariadb.IsExcludeEvents = true
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+user.Token,
|
||||
database,
|
||||
)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf(
|
||||
"Failed to update database with IsExcludeEvents. Status: %d, Body: %s",
|
||||
w.Code,
|
||||
w.Body.String(),
|
||||
)
|
||||
}
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, database.ID, storage.ID,
|
||||
backups_config.BackupEncryptionNone, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb_no_events"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMariadbRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMariadbDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
var eventCount int
|
||||
err = newDB.Get(
|
||||
&eventCount,
|
||||
"SELECT COUNT(*) FROM information_schema.events WHERE event_schema = ? AND event_name = 'test_event'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
0,
|
||||
eventCount,
|
||||
"Event 'test_event' should NOT exist in restored database when IsExcludeEvents is true",
|
||||
)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+database.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
@@ -385,13 +385,14 @@ func createMongodbDatabaseViaAPI(
|
||||
Type: databases.DatabaseTypeMongodb,
|
||||
Mongodb: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthDatabase: authDatabase,
|
||||
Version: version,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
@@ -432,13 +433,14 @@ func createMongodbRestoreViaAPI(
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MongodbDatabase: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthDatabase: authDatabase,
|
||||
Version: version,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -71,6 +71,7 @@ func tryInitVictoriaLogs() *VictoriaLogsWriter {
|
||||
|
||||
// Try to get config - this may fail early in startup
|
||||
url := getVictoriaLogsURL()
|
||||
username := getVictoriaLogsUsername()
|
||||
password := getVictoriaLogsPassword()
|
||||
|
||||
if url == "" {
|
||||
@@ -78,7 +79,7 @@ func tryInitVictoriaLogs() *VictoriaLogsWriter {
|
||||
return nil
|
||||
}
|
||||
|
||||
return NewVictoriaLogsWriter(url, password)
|
||||
return NewVictoriaLogsWriter(url, username, password)
|
||||
}
|
||||
|
||||
func ensureEnvLoaded() {
|
||||
@@ -126,6 +127,10 @@ func getVictoriaLogsURL() string {
|
||||
return os.Getenv("VICTORIA_LOGS_URL")
|
||||
}
|
||||
|
||||
func getVictoriaLogsUsername() string {
|
||||
return os.Getenv("VICTORIA_LOGS_USERNAME")
|
||||
}
|
||||
|
||||
func getVictoriaLogsPassword() string {
|
||||
return os.Getenv("VICTORIA_LOGS_PASSWORD")
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ type logEntry struct {
|
||||
|
||||
type VictoriaLogsWriter struct {
|
||||
url string
|
||||
username string
|
||||
password string
|
||||
httpClient *http.Client
|
||||
logChannel chan logEntry
|
||||
@@ -33,11 +34,12 @@ type VictoriaLogsWriter struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewVictoriaLogsWriter(url, password string) *VictoriaLogsWriter {
|
||||
func NewVictoriaLogsWriter(url, username, password string) *VictoriaLogsWriter {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
writer := &VictoriaLogsWriter{
|
||||
url: url,
|
||||
username: username,
|
||||
password: password,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
@@ -149,9 +151,9 @@ func (w *VictoriaLogsWriter) sendHTTP(entries []logEntry) error {
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/x-ndjson")
|
||||
|
||||
// Set Basic Auth (password as username, empty password)
|
||||
// Set Basic Auth (username:password)
|
||||
if w.password != "" {
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(w.password + ":"))
|
||||
auth := base64.StdEncoding.EncodeToString([]byte(w.username + ":" + w.password))
|
||||
req.Header.Set("Authorization", "Basic "+auth)
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ALTER COLUMN port DROP NOT NULL;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ADD COLUMN is_srv BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases DROP COLUMN is_srv;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ALTER COLUMN port SET NOT NULL;
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mariadb_databases
|
||||
ADD COLUMN IF NOT EXISTS is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mariadb_databases
|
||||
DROP COLUMN IF EXISTS is_exclude_events;
|
||||
-- +goose StatementEnd
|
||||
@@ -24,8 +24,6 @@ export function getApplicationServer() {
|
||||
}
|
||||
}
|
||||
|
||||
export const GOOGLE_DRIVE_OAUTH_REDIRECT_URL = 'https://databasus.com/storages/google-oauth';
|
||||
|
||||
export const APP_VERSION = (import.meta.env.VITE_APP_VERSION as string) || 'dev';
|
||||
|
||||
export const IS_CLOUD =
|
||||
|
||||
@@ -32,6 +32,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('mydb');
|
||||
expect(result.authDatabase).toBe('admin');
|
||||
expect(result.useTls).toBe(false);
|
||||
expect(result.isSrv).toBe(false);
|
||||
});
|
||||
|
||||
it('should parse connection string without database', () => {
|
||||
@@ -46,6 +47,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('');
|
||||
expect(result.authDatabase).toBe('admin');
|
||||
expect(result.useTls).toBe(false);
|
||||
expect(result.isSrv).toBe(false);
|
||||
});
|
||||
|
||||
it('should default port to 27017 when not specified', () => {
|
||||
@@ -107,6 +109,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.password).toBe('atlaspass');
|
||||
expect(result.database).toBe('mydb');
|
||||
expect(result.useTls).toBe(true); // SRV connections use TLS by default
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
|
||||
it('should parse mongodb+srv:// without database', () => {
|
||||
@@ -119,6 +122,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.host).toBe('cluster0.abc123.mongodb.net');
|
||||
expect(result.database).toBe('');
|
||||
expect(result.useTls).toBe(true);
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -314,13 +318,15 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.format).toBe('key-value');
|
||||
});
|
||||
|
||||
it('should return error for key-value format missing password', () => {
|
||||
const result = expectError(
|
||||
it('should allow missing password in key-value format (returns empty password)', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('host=localhost database=mydb user=admin'),
|
||||
);
|
||||
|
||||
expect(result.error).toContain('Password');
|
||||
expect(result.format).toBe('key-value');
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -351,12 +357,15 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.error).toContain('Username');
|
||||
});
|
||||
|
||||
it('should return error for missing password in URI', () => {
|
||||
const result = expectError(
|
||||
it('should allow missing password in URI (returns empty password)', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.error).toContain('Password');
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should return error for mysql:// format (wrong database type)', () => {
|
||||
@@ -446,4 +455,67 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Password Placeholder Handling', () => {
|
||||
it('should treat <db_password> placeholder as empty password in URI format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user:<db_password>@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should treat <password> placeholder as empty password in URI format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user:<password>@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should treat <db_password> placeholder as empty password in SRV format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'mongodb+srv://user:<db_password>@cluster0.mongodb.net/db',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('cluster0.mongodb.net');
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
|
||||
it('should treat <db_password> placeholder as empty password in key-value format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'host=localhost database=mydb user=admin password=<db_password>',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
|
||||
it('should treat <password> placeholder as empty password in key-value format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'host=localhost database=mydb user=admin password=<password>',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@ export type ParseResult = {
|
||||
database: string;
|
||||
authDatabase: string;
|
||||
useTls: boolean;
|
||||
isSrv: boolean;
|
||||
};
|
||||
|
||||
export type ParseError = {
|
||||
@@ -63,7 +64,8 @@ export class MongodbConnectionStringParser {
|
||||
const host = url.hostname;
|
||||
const port = url.port ? parseInt(url.port, 10) : isSrv ? 27017 : 27017;
|
||||
const username = decodeURIComponent(url.username);
|
||||
const password = decodeURIComponent(url.password);
|
||||
const rawPassword = decodeURIComponent(url.password);
|
||||
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword;
|
||||
const database = decodeURIComponent(url.pathname.slice(1));
|
||||
const authDatabase = this.getAuthSource(url.search) || 'admin';
|
||||
const useTls = isSrv ? true : this.checkTlsMode(url.search);
|
||||
@@ -76,10 +78,6 @@ export class MongodbConnectionStringParser {
|
||||
return { error: 'Username is missing from connection string' };
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
return { error: 'Password is missing from connection string' };
|
||||
}
|
||||
|
||||
return {
|
||||
host,
|
||||
port,
|
||||
@@ -88,6 +86,7 @@ export class MongodbConnectionStringParser {
|
||||
database: database || '',
|
||||
authDatabase,
|
||||
useTls,
|
||||
isSrv,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
@@ -114,7 +113,8 @@ export class MongodbConnectionStringParser {
|
||||
const port = params['port'];
|
||||
const database = params['database'] || params['dbname'] || params['db'];
|
||||
const username = params['user'] || params['username'];
|
||||
const password = params['password'];
|
||||
const rawPassword = params['password'];
|
||||
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword || '';
|
||||
const authDatabase = params['authSource'] || params['authDatabase'] || 'admin';
|
||||
const tls = params['tls'] || params['ssl'];
|
||||
|
||||
@@ -132,13 +132,6 @@ export class MongodbConnectionStringParser {
|
||||
};
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
return {
|
||||
error: 'Password is missing from connection string. Use password=yourpassword',
|
||||
format: 'key-value',
|
||||
};
|
||||
}
|
||||
|
||||
const useTls = this.isTlsEnabled(tls);
|
||||
|
||||
return {
|
||||
@@ -149,6 +142,7 @@ export class MongodbConnectionStringParser {
|
||||
database: database || '',
|
||||
authDatabase,
|
||||
useTls,
|
||||
isSrv: false,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
@@ -191,4 +185,11 @@ export class MongodbConnectionStringParser {
|
||||
const enabledValues = ['true', 'yes', '1'];
|
||||
return enabledValues.includes(lowercased);
|
||||
}
|
||||
|
||||
private static isPasswordPlaceholder(password: string | null | undefined): boolean {
|
||||
if (!password) return false;
|
||||
|
||||
const trimmed = password.trim();
|
||||
return trimmed === '<db_password>' || trimmed === '<password>';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,5 +10,6 @@ export interface MongodbDatabase {
|
||||
database: string;
|
||||
authDatabase: string;
|
||||
isHttps: boolean;
|
||||
isSrv: boolean;
|
||||
cpuCount: number;
|
||||
}
|
||||
|
||||
@@ -2,5 +2,4 @@ export interface GoogleDriveStorage {
|
||||
clientId: string;
|
||||
clientSecret: string;
|
||||
tokenJson?: string;
|
||||
useLocalRedirect?: boolean;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { Storage } from './Storage';
|
||||
|
||||
export interface StorageOauthDto {
|
||||
redirectUrl: string;
|
||||
storage: Storage;
|
||||
authCode: string;
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
const [isTestingConnection, setIsTestingConnection] = useState(false);
|
||||
const [isConnectionFailed, setIsConnectionFailed] = useState(false);
|
||||
|
||||
const hasAdvancedValues = !!database.mongodb?.authDatabase;
|
||||
const hasAdvancedValues = !!database.mongodb?.authDatabase || !!database.mongodb?.isSrv;
|
||||
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
|
||||
|
||||
const parseFromClipboard = async () => {
|
||||
@@ -75,17 +75,29 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
host: result.host,
|
||||
port: result.port,
|
||||
username: result.username,
|
||||
password: result.password,
|
||||
password: result.password || '',
|
||||
database: result.database,
|
||||
authDatabase: result.authDatabase,
|
||||
isHttps: result.useTls,
|
||||
isSrv: result.isSrv,
|
||||
cpuCount: 1,
|
||||
},
|
||||
};
|
||||
|
||||
if (result.isSrv) {
|
||||
setShowAdvanced(true);
|
||||
}
|
||||
|
||||
setEditingDatabase(updatedDatabase);
|
||||
setIsConnectionTested(false);
|
||||
message.success('Connection string parsed successfully');
|
||||
|
||||
if (!result.password) {
|
||||
message.warning(
|
||||
'Connection string parsed successfully. Please enter the password manually.',
|
||||
);
|
||||
} else {
|
||||
message.success('Connection string parsed successfully');
|
||||
}
|
||||
} catch {
|
||||
message.error('Failed to read clipboard. Please check browser permissions.');
|
||||
}
|
||||
@@ -156,9 +168,11 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
|
||||
if (!editingDatabase) return null;
|
||||
|
||||
const isSrvConnection = editingDatabase.mongodb?.isSrv || false;
|
||||
|
||||
let isAllFieldsFilled = true;
|
||||
if (!editingDatabase.mongodb?.host) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.port) isAllFieldsFilled = false;
|
||||
if (!isSrvConnection && !editingDatabase.mongodb?.port) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.username) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.id && !editingDatabase.mongodb?.password) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.database) isAllFieldsFilled = false;
|
||||
@@ -220,25 +234,27 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Port</div>
|
||||
<InputNumber
|
||||
type="number"
|
||||
value={editingDatabase.mongodb?.port}
|
||||
onChange={(e) => {
|
||||
if (!editingDatabase.mongodb || e === null) return;
|
||||
{!isSrvConnection && (
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Port</div>
|
||||
<InputNumber
|
||||
type="number"
|
||||
value={editingDatabase.mongodb?.port}
|
||||
onChange={(e) => {
|
||||
if (!editingDatabase.mongodb || e === null) return;
|
||||
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, port: e },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
className="max-w-[200px] grow"
|
||||
placeholder="27017"
|
||||
/>
|
||||
</div>
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, port: e },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
className="max-w-[200px] grow"
|
||||
placeholder="27017"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Username</div>
|
||||
@@ -366,6 +382,31 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
|
||||
{isShowAdvanced && (
|
||||
<>
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Use SRV connection</div>
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
checked={editingDatabase.mongodb?.isSrv || false}
|
||||
onChange={(checked) => {
|
||||
if (!editingDatabase.mongodb) return;
|
||||
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, isSrv: checked },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
/>
|
||||
<Tooltip
|
||||
className="cursor-pointer"
|
||||
title="Enable for MongoDB Atlas SRV connections (mongodb+srv://). Port is not required for SRV connections."
|
||||
>
|
||||
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Auth database</div>
|
||||
<Input
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import { DownOutlined, InfoCircleOutlined, UpOutlined } from '@ant-design/icons';
|
||||
import { Button, Checkbox, Input, Tooltip } from 'antd';
|
||||
import { useState } from 'react';
|
||||
import { Button, Input } from 'antd';
|
||||
|
||||
import { GOOGLE_DRIVE_OAUTH_REDIRECT_URL } from '../../../../../constants';
|
||||
import type { Storage } from '../../../../../entity/storages';
|
||||
import type { StorageOauthDto } from '../../../../../entity/storages/models/StorageOauthDto';
|
||||
|
||||
@@ -13,23 +10,16 @@ interface Props {
|
||||
}
|
||||
|
||||
export function EditGoogleDriveStorageComponent({ storage, setStorage, setUnsaved }: Props) {
|
||||
const hasAdvancedValues = !!storage?.googleDriveStorage?.useLocalRedirect;
|
||||
const [showAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
|
||||
const goToAuthUrl = () => {
|
||||
if (!storage?.googleDriveStorage?.clientId || !storage?.googleDriveStorage?.clientSecret) {
|
||||
return;
|
||||
}
|
||||
|
||||
const localRedirectUri = `${window.location.origin}/storages/google-oauth`;
|
||||
const useLocal = storage.googleDriveStorage.useLocalRedirect;
|
||||
const redirectUri = useLocal ? localRedirectUri : GOOGLE_DRIVE_OAUTH_REDIRECT_URL;
|
||||
|
||||
const redirectUri = `${window.location.origin}/storages/google-oauth`;
|
||||
const clientId = storage.googleDriveStorage.clientId;
|
||||
const scope = 'https://www.googleapis.com/auth/drive.file';
|
||||
const originUrl = `${window.location.origin}/storages/google-oauth`;
|
||||
|
||||
const oauthDto: StorageOauthDto = {
|
||||
redirectUrl: originUrl,
|
||||
storage: storage,
|
||||
authCode: '',
|
||||
};
|
||||
@@ -99,53 +89,6 @@ export function EditGoogleDriveStorageComponent({ storage, setStorage, setUnsave
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div className="mt-4 mb-3 flex items-center">
|
||||
<div
|
||||
className="flex cursor-pointer items-center text-sm text-blue-600 hover:text-blue-800"
|
||||
onClick={() => setShowAdvanced(!showAdvanced)}
|
||||
>
|
||||
<span className="mr-2">Advanced settings</span>
|
||||
|
||||
{showAdvanced ? (
|
||||
<UpOutlined style={{ fontSize: '12px' }} />
|
||||
) : (
|
||||
<DownOutlined style={{ fontSize: '12px' }} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{showAdvanced && (
|
||||
<div className="mb-4 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
<div className="flex items-center">
|
||||
<Checkbox
|
||||
checked={storage?.googleDriveStorage?.useLocalRedirect || false}
|
||||
onChange={(e) => {
|
||||
if (!storage?.googleDriveStorage) return;
|
||||
|
||||
setStorage({
|
||||
...storage,
|
||||
googleDriveStorage: {
|
||||
...storage.googleDriveStorage,
|
||||
useLocalRedirect: e.target.checked,
|
||||
},
|
||||
});
|
||||
setUnsaved();
|
||||
}}
|
||||
disabled={!!storage?.googleDriveStorage?.tokenJson}
|
||||
>
|
||||
<span>Use local redirect</span>
|
||||
</Checkbox>
|
||||
|
||||
<Tooltip
|
||||
className="cursor-pointer"
|
||||
title="When enabled, uses your address as the origin and redirect URL (specify it in Google Cloud Console). HTTPS is required."
|
||||
>
|
||||
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{storage?.googleDriveStorage?.tokenJson && (
|
||||
<>
|
||||
<div className="mb-1 flex w-full flex-col items-start sm:flex-row sm:items-center">
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { Modal, Spin } from 'antd';
|
||||
import { useEffect, useState } from 'react';
|
||||
|
||||
import { GOOGLE_DRIVE_OAUTH_REDIRECT_URL } from '../constants';
|
||||
import { type Storage, StorageType } from '../entity/storages';
|
||||
import type { StorageOauthDto } from '../entity/storages/models/StorageOauthDto';
|
||||
import type { UserProfile } from '../entity/users';
|
||||
@@ -21,7 +20,7 @@ export function OauthStorageComponent() {
|
||||
const { clientId, clientSecret } = oauthDto.storage.googleDriveStorage;
|
||||
const { authCode } = oauthDto;
|
||||
|
||||
const redirectUri = oauthDto.redirectUrl || GOOGLE_DRIVE_OAUTH_REDIRECT_URL;
|
||||
const redirectUri = `${window.location.origin}/storages/google-oauth`;
|
||||
|
||||
try {
|
||||
// Exchange authorization code for access token
|
||||
@@ -84,33 +83,14 @@ export function OauthStorageComponent() {
|
||||
});
|
||||
|
||||
const urlParams = new URLSearchParams(window.location.search);
|
||||
|
||||
// Attempt 1: Check for the 'oauthDto' param (Third-party/Legacy way)
|
||||
const oauthDtoParam = urlParams.get('oauthDto');
|
||||
if (oauthDtoParam) {
|
||||
try {
|
||||
const decodedParam = decodeURIComponent(oauthDtoParam);
|
||||
const oauthDto: StorageOauthDto = JSON.parse(decodedParam);
|
||||
processOauthDto(oauthDto);
|
||||
return;
|
||||
} catch (e) {
|
||||
console.error('Error parsing oauthDto parameter:', e);
|
||||
alert('Malformed OAuth parameter received');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt 2: Check for 'code' and 'state' (Direct Google/Local way)
|
||||
const code = urlParams.get('code');
|
||||
const state = urlParams.get('state');
|
||||
|
||||
if (code && state) {
|
||||
try {
|
||||
// The 'state' parameter contains our stringified StorageOauthDto
|
||||
const decodedState = decodeURIComponent(state);
|
||||
const oauthDto: StorageOauthDto = JSON.parse(decodedState);
|
||||
|
||||
// Inject the authorization code received from Google
|
||||
oauthDto.authCode = code;
|
||||
|
||||
processOauthDto(oauthDto);
|
||||
@@ -122,7 +102,6 @@ export function OauthStorageComponent() {
|
||||
}
|
||||
}
|
||||
|
||||
// Attempt 3: No valid parameters found
|
||||
alert('OAuth param not found. Ensure the redirect URL is configured correctly.');
|
||||
}, []);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user