diff --git a/backend/internal/features/backups/backups/controller_test.go b/backend/internal/features/backups/backups/controller_test.go index b919ae5..252fd41 100644 --- a/backend/internal/features/backups/backups/controller_test.go +++ b/backend/internal/features/backups/backups/controller_test.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/http" + "strconv" "strings" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" backups_config "databasus-backend/internal/features/backups/config" "databasus-backend/internal/features/databases" @@ -679,7 +681,13 @@ func createTestDatabase( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" + env := config.GetEnv() + port, err := strconv.Atoi(env.TestPostgres16Port) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err)) + } + + testDbName := "testdb" request := databases.Database{ Name: name, WorkspaceID: &workspaceID, @@ -687,9 +695,9 @@ func createTestDatabase( Postgresql: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", + Port: port, + Username: "testuser", + Password: "testpassword", Database: &testDbName, CpuCount: 1, }, diff --git a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go index 5315964..b14029f 100644 --- a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go @@ -107,12 +107,14 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs( "--user=" + mdb.Username, "--single-transaction", "--routines", - "--triggers", "--quick", "--verbose", } - if !mdb.IsExcludeEvents { + if mdb.HasPrivilege("TRIGGER") { + args = append(args, "--triggers") + } + if mdb.HasPrivilege("EVENT") { args = append(args, "--events") } diff --git a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go index 7eee4d7..50a0a48 100644 --- a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go @@ -105,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab "--user=" + my.Username, "--single-transaction", "--routines", - "--triggers", - "--events", "--set-gtid-purged=OFF", "--quick", "--verbose", } + if my.HasPrivilege("TRIGGER") { + args = append(args, "--triggers") + } + if my.HasPrivilege("EVENT") { + args = append(args, "--events") + } + args = append(args, uc.getNetworkCompressionArgs(my.Version)...) if my.IsHttps { diff --git a/backend/internal/features/backups/config/controller_test.go b/backend/internal/features/backups/config/controller_test.go index 15067ef..841fb28 100644 --- a/backend/internal/features/backups/config/controller_test.go +++ b/backend/internal/features/backups/config/controller_test.go @@ -2,13 +2,16 @@ package backups_config import ( "encoding/json" + "fmt" "net/http" + "strconv" "testing" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/stretchr/testify/assert" + "databasus-backend/internal/config" "databasus-backend/internal/features/databases" "databasus-backend/internal/features/databases/databases/postgresql" "databasus-backend/internal/features/intervals" @@ -1434,7 +1437,13 @@ func createTestDatabaseViaAPI( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" + env := config.GetEnv() + port, err := strconv.Atoi(env.TestPostgres16Port) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err)) + } + + testDbName := "testdb" request := databases.Database{ WorkspaceID: &workspaceID, Name: name, @@ -1442,9 +1451,9 @@ func createTestDatabaseViaAPI( Postgresql: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", + Port: port, + Username: "testuser", + Password: "testpassword", Database: &testDbName, CpuCount: 1, }, @@ -1459,7 +1468,9 @@ func createTestDatabaseViaAPI( ) if w.Code != http.StatusCreated { - panic("Failed to create database") + panic( + fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()), + ) } var database databases.Database diff --git a/backend/internal/features/databases/controller_test.go b/backend/internal/features/databases/controller_test.go index a43de9c..c6359e1 100644 --- a/backend/internal/features/databases/controller_test.go +++ b/backend/internal/features/databases/controller_test.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "strings" "testing" @@ -11,6 +12,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "databasus-backend/internal/config" "databasus-backend/internal/features/databases/databases/mariadb" "databasus-backend/internal/features/databases/databases/mongodb" "databasus-backend/internal/features/databases/databases/postgresql" @@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine { return router } +func getTestPostgresConfig() *postgresql.PostgresqlDatabase { + env := config.GetEnv() + port, err := strconv.Atoi(env.TestPostgres16Port) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err)) + } + + testDbName := "testdb" + return &postgresql.PostgresqlDatabase{ + Version: tools.PostgresqlVersion16, + Host: "localhost", + Port: port, + Username: "testuser", + Password: "testpassword", + Database: &testDbName, + CpuCount: 1, + } +} + +func getTestMariadbConfig() *mariadb.MariadbDatabase { + env := config.GetEnv() + portStr := env.TestMariadb1011Port + if portStr == "" { + portStr = "33111" + } + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err)) + } + + testDbName := "testdb" + return &mariadb.MariadbDatabase{ + Version: tools.MariadbVersion1011, + Host: "localhost", + Port: port, + Username: "testuser", + Password: "testpassword", + Database: &testDbName, + } +} + +func getTestMongodbConfig() *mongodb.MongodbDatabase { + env := config.GetEnv() + portStr := env.TestMongodb70Port + if portStr == "" { + portStr = "27070" + } + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err)) + } + + return &mongodb.MongodbDatabase{ + Version: tools.MongodbVersion7, + Host: "localhost", + Port: port, + Username: "root", + Password: "rootpassword", + Database: "testdb", + AuthDatabase: "admin", + IsHttps: false, + CpuCount: 1, + } +} + func Test_CreateDatabase_PermissionsEnforced(t *testing.T) { tests := []struct { name string @@ -88,20 +155,11 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) { testUserToken = member.Token } - testDbName := "test_db" request := Database{ Name: "Test Database", WorkspaceID: &workspace.ID, Type: DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: getTestPostgresConfig(), } var response Database @@ -132,20 +190,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember) - testDbName := "test_db" request := Database{ Name: "Test Database", WorkspaceID: &workspace.ID, Type: DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: getTestPostgresConfig(), } testResp := test_utils.MakePostRequest( @@ -737,7 +786,13 @@ func createTestDatabaseViaAPI( token string, router *gin.Engine, ) *Database { - testDbName := "test_db" + env := config.GetEnv() + port, err := strconv.Atoi(env.TestPostgres16Port) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err)) + } + + testDbName := "testdb" request := Database{ Name: name, WorkspaceID: &workspaceID, @@ -745,9 +800,9 @@ func createTestDatabaseViaAPI( Postgresql: &postgresql.PostgresqlDatabase{ Version: tools.PostgresqlVersion16, Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", + Port: port, + Username: "testuser", + Password: "testpassword", Database: &testDbName, CpuCount: 1, }, @@ -780,21 +835,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) { owner := users_testing.CreateTestUser(users_enums.UserRoleMember) workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) - testDbName := "test_db" - plainPassword := "my-super-secret-password-123" + pgConfig := getTestPostgresConfig() + plainPassword := "testpassword" + pgConfig.Password = plainPassword request := Database{ Name: "Test Database", WorkspaceID: &workspace.ID, Type: DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: plainPassword, - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: pgConfig, } var createdDatabase Database @@ -854,38 +902,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { name: "PostgreSQL Database", databaseType: DatabaseTypePostgres, createDatabase: func(workspaceID uuid.UUID) *Database { - testDbName := "test_db" + pgConfig := getTestPostgresConfig() return &Database{ WorkspaceID: &workspaceID, Name: "Test PostgreSQL Database", Type: DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "original-password-secret", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: pgConfig, } }, updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database { - testDbName := "updated_test_db" + pgConfig := getTestPostgresConfig() + pgConfig.Password = "" return &Database{ ID: databaseID, WorkspaceID: &workspaceID, Name: "Updated PostgreSQL Database", Type: DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion17, - Host: "updated-host", - Port: 5433, - Username: "updated_user", - Password: "", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: pgConfig, } }, verifySensitiveData: func(t *testing.T, database *Database) { @@ -895,7 +928,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { encryptor := encryption.GetFieldEncryptor() decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password) assert.NoError(t, err) - assert.Equal(t, "original-password-secret", decrypted) + assert.Equal(t, "testpassword", decrypted) }, verifyHiddenData: func(t *testing.T, database *Database) { assert.Equal(t, "", database.Postgresql.Password) @@ -905,36 +938,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { name: "MariaDB Database", databaseType: DatabaseTypeMariadb, createDatabase: func(workspaceID uuid.UUID) *Database { - testDbName := "test_db" + mariaConfig := getTestMariadbConfig() return &Database{ WorkspaceID: &workspaceID, Name: "Test MariaDB Database", Type: DatabaseTypeMariadb, - Mariadb: &mariadb.MariadbDatabase{ - Version: tools.MariadbVersion1011, - Host: "localhost", - Port: 3306, - Username: "root", - Password: "original-password-secret", - Database: &testDbName, - }, + Mariadb: mariaConfig, } }, updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database { - testDbName := "updated_test_db" + mariaConfig := getTestMariadbConfig() + mariaConfig.Password = "" return &Database{ ID: databaseID, WorkspaceID: &workspaceID, Name: "Updated MariaDB Database", Type: DatabaseTypeMariadb, - Mariadb: &mariadb.MariadbDatabase{ - Version: tools.MariadbVersion114, - Host: "updated-host", - Port: 3307, - Username: "updated_user", - Password: "", - Database: &testDbName, - }, + Mariadb: mariaConfig, } }, verifySensitiveData: func(t *testing.T, database *Database) { @@ -944,7 +964,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { encryptor := encryption.GetFieldEncryptor() decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password) assert.NoError(t, err) - assert.Equal(t, "original-password-secret", decrypted) + assert.Equal(t, "testpassword", decrypted) }, verifyHiddenData: func(t *testing.T, database *Database) { assert.Equal(t, "", database.Mariadb.Password) @@ -954,40 +974,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { name: "MongoDB Database", databaseType: DatabaseTypeMongodb, createDatabase: func(workspaceID uuid.UUID) *Database { + mongoConfig := getTestMongodbConfig() return &Database{ WorkspaceID: &workspaceID, Name: "Test MongoDB Database", Type: DatabaseTypeMongodb, - Mongodb: &mongodb.MongodbDatabase{ - Version: tools.MongodbVersion7, - Host: "localhost", - Port: 27017, - Username: "root", - Password: "original-password-secret", - Database: "test_db", - AuthDatabase: "admin", - IsHttps: false, - CpuCount: 1, - }, + Mongodb: mongoConfig, } }, updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database { + mongoConfig := getTestMongodbConfig() + mongoConfig.Password = "" return &Database{ ID: databaseID, WorkspaceID: &workspaceID, Name: "Updated MongoDB Database", Type: DatabaseTypeMongodb, - Mongodb: &mongodb.MongodbDatabase{ - Version: tools.MongodbVersion8, - Host: "updated-host", - Port: 27018, - Username: "updated_user", - Password: "", - Database: "updated_test_db", - AuthDatabase: "admin", - IsHttps: false, - CpuCount: 1, - }, + Mongodb: mongoConfig, } }, verifySensitiveData: func(t *testing.T, database *Database) { @@ -997,7 +1000,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) { encryptor := encryption.GetFieldEncryptor() decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password) assert.NoError(t, err) - assert.Equal(t, "original-password-secret", decrypted) + assert.Equal(t, "rootpassword", decrypted) }, verifyHiddenData: func(t *testing.T, database *Database) { assert.Equal(t, "", database.Mongodb.Password) diff --git a/backend/internal/features/databases/databases/mariadb/model.go b/backend/internal/features/databases/databases/mariadb/model.go index 732b8a4..cd85c0f 100644 --- a/backend/internal/features/databases/databases/mariadb/model.go +++ b/backend/internal/features/databases/databases/mariadb/model.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "regexp" + "sort" "strings" "time" @@ -23,15 +24,13 @@ type MariadbDatabase struct { Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"` - Host string `json:"host" gorm:"type:text;not null"` - Port int `json:"port" gorm:"type:int;not null"` - Username string `json:"username" gorm:"type:text;not null"` - Password string `json:"password" gorm:"type:text;not null"` - Database *string `json:"database" gorm:"type:text"` - IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"` - - // advanced - IsExcludeEvents bool `json:"isExcludeEvents" gorm:"column:is_exclude_events;type:boolean;default:false"` + Host string `json:"host" gorm:"type:text;not null"` + Port int `json:"port" gorm:"type:int;not null"` + Username string `json:"username" gorm:"type:text;not null"` + Password string `json:"password" gorm:"type:text;not null"` + Database *string `json:"database" gorm:"type:text"` + IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"` + Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"` } func (m *MariadbDatabase) TableName() string { @@ -97,7 +96,13 @@ func (m *MariadbDatabase) TestConnection( } m.Version = detectedVersion - if err := checkBackupPermissions(ctx, db, *m.Database); err != nil { + privileges, err := detectPrivileges(ctx, db, *m.Database) + if err != nil { + return err + } + m.Privileges = privileges + + if err := checkBackupPermissions(m.Privileges); err != nil { return err } @@ -118,7 +123,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) { m.Username = incoming.Username m.Database = incoming.Database m.IsHttps = incoming.IsHttps - m.IsExcludeEvents = incoming.IsExcludeEvents + m.Privileges = incoming.Privileges if incoming.Password != "" { m.Password = incoming.Password @@ -139,15 +144,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields( return nil } -func (m *MariadbDatabase) PopulateVersionIfEmpty( +func (m *MariadbDatabase) PopulateDbData( logger *slog.Logger, encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { - if m.Version != "" { + if m.Database == nil || *m.Database == "" { return nil } - return m.PopulateVersion(logger, encryptor, databaseID) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID) + if err != nil { + return fmt.Errorf("failed to decrypt password: %w", err) + } + + dsn := m.buildDSN(password, *m.Database) + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + defer func() { + if closeErr := db.Close(); closeErr != nil { + logger.Error("Failed to close connection", "error", closeErr) + } + }() + + detectedVersion, err := detectMariadbVersion(ctx, db) + if err != nil { + return err + } + m.Version = detectedVersion + + privileges, err := detectPrivileges(ctx, db, *m.Database) + if err != nil { + return err + } + m.Privileges = privileges + + return nil } func (m *MariadbDatabase) PopulateVersion( @@ -183,8 +221,8 @@ func (m *MariadbDatabase) PopulateVersion( if err != nil { return err } - m.Version = detectedVersion + return nil } @@ -345,10 +383,23 @@ func (m *MariadbDatabase) CreateReadOnlyUser( return "", "", errors.New("failed to generate unique username after 3 attempts") } +func (m *MariadbDatabase) HasPrivilege(priv string) bool { + return HasPrivilege(m.Privileges, priv) +} + +func HasPrivilege(privileges, priv string) bool { + for _, p := range strings.Split(privileges, ",") { + if strings.TrimSpace(p) == priv { + return true + } + } + return false +} + func (m *MariadbDatabase) buildDSN(password string, database string) string { tlsConfig := "false" if m.IsHttps { - tlsConfig = "true" + tlsConfig = "skip-verify" } return fmt.Sprintf( @@ -439,22 +490,19 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) { } } -// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup. -// Required privileges: SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT on database; PROCESS globally. -func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) error { +// detectPrivileges detects backup-related privileges and returns them as comma-separated string +func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) { rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()") if err != nil { - return fmt.Errorf("failed to check grants: %w", err) + return "", fmt.Errorf("failed to check grants: %w", err) } defer func() { _ = rows.Close() }() - requiredDBPrivileges := map[string]bool{ - "SELECT": false, - "SHOW VIEW": false, - "LOCK TABLES": false, - "TRIGGER": false, - "EVENT": false, + backupPrivileges := []string{ + "SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT", } + + detectedPrivileges := make(map[string]bool) hasProcess := false hasAllPrivileges := false @@ -467,7 +515,7 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er for rows.Next() { var grant string if err := rows.Scan(&grant); err != nil { - return fmt.Errorf("failed to scan grant: %w", err) + return "", fmt.Errorf("failed to scan grant: %w", err) } if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) { @@ -477,9 +525,9 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er } if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) { - for priv := range requiredDBPrivileges { + for _, priv := range backupPrivileges { if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) { - requiredDBPrivileges[priv] = true + detectedPrivileges[priv] = true } } } @@ -491,26 +539,43 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er } if err := rows.Err(); err != nil { - return fmt.Errorf("error iterating grants: %w", err) + return "", fmt.Errorf("error iterating grants: %w", err) } if hasAllPrivileges { - return nil + for _, priv := range backupPrivileges { + detectedPrivileges[priv] = true + } + hasProcess = true } + privileges := make([]string, 0, len(detectedPrivileges)+1) + for priv := range detectedPrivileges { + privileges = append(privileges, priv) + } + if hasProcess { + privileges = append(privileges, "PROCESS") + } + + sort.Strings(privileges) + return strings.Join(privileges, ","), nil +} + +// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup. +// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT. +func checkBackupPermissions(privileges string) error { + requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"} + var missingPrivileges []string - for priv, has := range requiredDBPrivileges { - if !has { + for _, priv := range requiredPrivileges { + if !HasPrivilege(privileges, priv) { missingPrivileges = append(missingPrivileges, priv) } } - if !hasProcess { - missingPrivileges = append(missingPrivileges, "PROCESS (global)") - } if len(missingPrivileges) > 0 { return fmt.Errorf( - "insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT on database; PROCESS globally", + "insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS", strings.Join(missingPrivileges, ", "), ) } diff --git a/backend/internal/features/databases/databases/mongodb/model.go b/backend/internal/features/databases/databases/mongodb/model.go index f2a379b..c4f2244 100644 --- a/backend/internal/features/databases/databases/mongodb/model.go +++ b/backend/internal/features/databases/databases/mongodb/model.go @@ -140,14 +140,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields( return nil } -func (m *MongodbDatabase) PopulateVersionIfEmpty( +func (m *MongodbDatabase) PopulateDbData( logger *slog.Logger, encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { - if m.Version != "" { - return nil - } return m.PopulateVersion(logger, encryptor, databaseID) } @@ -447,20 +444,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string { authDB = "admin" } - tlsOption := "false" + tlsParams := "" if m.IsHttps { - tlsOption = "true" + tlsParams = "&tls=true&tlsInsecure=true" } return fmt.Sprintf( - "mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000", + "mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s", url.QueryEscape(m.Username), url.QueryEscape(password), m.Host, m.Port, m.Database, authDB, - tlsOption, + tlsParams, ) } @@ -471,19 +468,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string { authDB = "admin" } - tlsOption := "false" + tlsParams := "" if m.IsHttps { - tlsOption = "true" + tlsParams = "&tls=true&tlsInsecure=true" } return fmt.Sprintf( - "mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000", + "mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s", url.QueryEscape(m.Username), url.QueryEscape(password), m.Host, m.Port, authDB, - tlsOption, + tlsParams, ) } diff --git a/backend/internal/features/databases/databases/mysql/model.go b/backend/internal/features/databases/databases/mysql/model.go index a041b62..461a900 100644 --- a/backend/internal/features/databases/databases/mysql/model.go +++ b/backend/internal/features/databases/databases/mysql/model.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "regexp" + "sort" "strings" "time" @@ -23,12 +24,13 @@ type MysqlDatabase struct { Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"` - Host string `json:"host" gorm:"type:text;not null"` - Port int `json:"port" gorm:"type:int;not null"` - Username string `json:"username" gorm:"type:text;not null"` - Password string `json:"password" gorm:"type:text;not null"` - Database *string `json:"database" gorm:"type:text"` - IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"` + Host string `json:"host" gorm:"type:text;not null"` + Port int `json:"port" gorm:"type:int;not null"` + Username string `json:"username" gorm:"type:text;not null"` + Password string `json:"password" gorm:"type:text;not null"` + Database *string `json:"database" gorm:"type:text"` + IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"` + Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"` } func (m *MysqlDatabase) TableName() string { @@ -94,7 +96,13 @@ func (m *MysqlDatabase) TestConnection( } m.Version = detectedVersion - if err := checkBackupPermissions(ctx, db, *m.Database); err != nil { + privileges, err := detectPrivileges(ctx, db, *m.Database) + if err != nil { + return err + } + m.Privileges = privileges + + if err := checkBackupPermissions(m.Privileges); err != nil { return err } @@ -115,6 +123,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) { m.Username = incoming.Username m.Database = incoming.Database m.IsHttps = incoming.IsHttps + m.Privileges = incoming.Privileges if incoming.Password != "" { m.Password = incoming.Password @@ -135,15 +144,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields( return nil } -func (m *MysqlDatabase) PopulateVersionIfEmpty( +func (m *MysqlDatabase) PopulateDbData( logger *slog.Logger, encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { - if m.Version != "" { + if m.Database == nil || *m.Database == "" { return nil } - return m.PopulateVersion(logger, encryptor, databaseID) + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID) + if err != nil { + return fmt.Errorf("failed to decrypt password: %w", err) + } + + dsn := m.buildDSN(password, *m.Database) + + db, err := sql.Open("mysql", dsn) + if err != nil { + return fmt.Errorf("failed to connect to database: %w", err) + } + defer func() { + if closeErr := db.Close(); closeErr != nil { + logger.Error("Failed to close connection", "error", closeErr) + } + }() + + detectedVersion, err := detectMysqlVersion(ctx, db) + if err != nil { + return err + } + m.Version = detectedVersion + + privileges, err := detectPrivileges(ctx, db, *m.Database) + if err != nil { + return err + } + m.Privileges = privileges + + return nil } func (m *MysqlDatabase) PopulateVersion( @@ -179,8 +221,8 @@ func (m *MysqlDatabase) PopulateVersion( if err != nil { return err } - m.Version = detectedVersion + return nil } @@ -342,10 +384,23 @@ func (m *MysqlDatabase) CreateReadOnlyUser( return "", "", errors.New("failed to generate unique username after 3 attempts") } +func (m *MysqlDatabase) HasPrivilege(priv string) bool { + return HasPrivilege(m.Privileges, priv) +} + +func HasPrivilege(privileges, priv string) bool { + for p := range strings.SplitSeq(privileges, ",") { + if strings.TrimSpace(p) == priv { + return true + } + } + return false +} + func (m *MysqlDatabase) buildDSN(password string, database string) string { tlsConfig := "false" if m.IsHttps { - tlsConfig = "true" + tlsConfig = "skip-verify" } return fmt.Sprintf( @@ -405,22 +460,19 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion { } } -// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup. -// Required privileges: SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT on database; PROCESS globally. -func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) error { +// detectPrivileges detects backup-related privileges and returns them as comma-separated string +func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) { rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()") if err != nil { - return fmt.Errorf("failed to check grants: %w", err) + return "", fmt.Errorf("failed to check grants: %w", err) } defer func() { _ = rows.Close() }() - requiredDBPrivileges := map[string]bool{ - "SELECT": false, - "SHOW VIEW": false, - "LOCK TABLES": false, - "TRIGGER": false, - "EVENT": false, + backupPrivileges := []string{ + "SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT", } + + detectedPrivileges := make(map[string]bool) hasProcess := false hasAllPrivileges := false @@ -433,7 +485,7 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er for rows.Next() { var grant string if err := rows.Scan(&grant); err != nil { - return fmt.Errorf("failed to scan grant: %w", err) + return "", fmt.Errorf("failed to scan grant: %w", err) } if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) { @@ -443,9 +495,9 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er } if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) { - for priv := range requiredDBPrivileges { + for _, priv := range backupPrivileges { if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) { - requiredDBPrivileges[priv] = true + detectedPrivileges[priv] = true } } } @@ -457,26 +509,43 @@ func checkBackupPermissions(ctx context.Context, db *sql.DB, database string) er } if err := rows.Err(); err != nil { - return fmt.Errorf("error iterating grants: %w", err) + return "", fmt.Errorf("error iterating grants: %w", err) } if hasAllPrivileges { - return nil + for _, priv := range backupPrivileges { + detectedPrivileges[priv] = true + } + hasProcess = true } + privileges := make([]string, 0, len(detectedPrivileges)+1) + for priv := range detectedPrivileges { + privileges = append(privileges, priv) + } + if hasProcess { + privileges = append(privileges, "PROCESS") + } + + sort.Strings(privileges) + return strings.Join(privileges, ","), nil +} + +// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup. +// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT. +func checkBackupPermissions(privileges string) error { + requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"} + var missingPrivileges []string - for priv, has := range requiredDBPrivileges { - if !has { + for _, priv := range requiredPrivileges { + if !HasPrivilege(privileges, priv) { missingPrivileges = append(missingPrivileges, priv) } } - if !hasProcess { - missingPrivileges = append(missingPrivileges, "PROCESS (global)") - } if len(missingPrivileges) > 0 { return fmt.Errorf( - "insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT on database; PROCESS globally", + "insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS", strings.Join(missingPrivileges, ", "), ) } diff --git a/backend/internal/features/databases/databases/postgresql/model.go b/backend/internal/features/databases/databases/postgresql/model.go index 073c7f1..32581de 100644 --- a/backend/internal/features/databases/databases/postgresql/model.go +++ b/backend/internal/features/databases/databases/postgresql/model.go @@ -137,16 +137,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields( return nil } -// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set. +// PopulateDbData detects and sets the PostgreSQL version. // This should be called before encrypting sensitive fields. -func (p *PostgresqlDatabase) PopulateVersionIfEmpty( +func (p *PostgresqlDatabase) PopulateDbData( logger *slog.Logger, encryptor encryption.FieldEncryptor, databaseID uuid.UUID, ) error { - if p.Version != "" { - return nil - } return p.PopulateVersion(logger, encryptor, databaseID) } diff --git a/backend/internal/features/databases/model.go b/backend/internal/features/databases/model.go index bf420e9..fd00534 100644 --- a/backend/internal/features/databases/model.go +++ b/backend/internal/features/databases/model.go @@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e return nil } -func (d *Database) PopulateVersionIfEmpty( +func (d *Database) PopulateDbData( logger *slog.Logger, encryptor encryption.FieldEncryptor, ) error { if d.Postgresql != nil { - return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID) + return d.Postgresql.PopulateDbData(logger, encryptor, d.ID) } if d.Mysql != nil { - return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID) + return d.Mysql.PopulateDbData(logger, encryptor, d.ID) } if d.Mariadb != nil { - return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID) + return d.Mariadb.PopulateDbData(logger, encryptor, d.ID) } if d.Mongodb != nil { - return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID) + return d.Mongodb.PopulateDbData(logger, encryptor, d.ID) } return nil } diff --git a/backend/internal/features/databases/service.go b/backend/internal/features/databases/service.go index 35778e8..323c79c 100644 --- a/backend/internal/features/databases/service.go +++ b/backend/internal/features/databases/service.go @@ -82,8 +82,8 @@ func (s *DatabaseService) CreateDatabase( return nil, err } - if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil { - return nil, fmt.Errorf("failed to auto-detect database version: %w", err) + if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil { + return nil, fmt.Errorf("failed to auto-detect database data: %w", err) } if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil { @@ -149,8 +149,8 @@ func (s *DatabaseService) UpdateDatabase( return err } - if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil { - return fmt.Errorf("failed to auto-detect database version: %w", err) + if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil { + return fmt.Errorf("failed to auto-detect database data: %w", err) } if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil { diff --git a/backend/internal/features/databases/testing.go b/backend/internal/features/databases/testing.go index ee523bf..bab070d 100644 --- a/backend/internal/features/databases/testing.go +++ b/backend/internal/features/databases/testing.go @@ -1,6 +1,12 @@ package databases import ( + "fmt" + "strconv" + + "databasus-backend/internal/config" + "databasus-backend/internal/features/databases/databases/mariadb" + "databasus-backend/internal/features/databases/databases/mongodb" "databasus-backend/internal/features/databases/databases/postgresql" "databasus-backend/internal/features/notifiers" "databasus-backend/internal/features/storages" @@ -9,6 +15,71 @@ import ( "github.com/google/uuid" ) +func GetTestPostgresConfig() *postgresql.PostgresqlDatabase { + env := config.GetEnv() + port, err := strconv.Atoi(env.TestPostgres16Port) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err)) + } + + testDbName := "testdb" + return &postgresql.PostgresqlDatabase{ + Version: tools.PostgresqlVersion16, + Host: "localhost", + Port: port, + Username: "testuser", + Password: "testpassword", + Database: &testDbName, + CpuCount: 1, + } +} + +func GetTestMariadbConfig() *mariadb.MariadbDatabase { + env := config.GetEnv() + portStr := env.TestMariadb1011Port + if portStr == "" { + portStr = "33111" + } + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err)) + } + + testDbName := "testdb" + return &mariadb.MariadbDatabase{ + Version: tools.MariadbVersion1011, + Host: "localhost", + Port: port, + Username: "testuser", + Password: "testpassword", + Database: &testDbName, + } +} + +func GetTestMongodbConfig() *mongodb.MongodbDatabase { + env := config.GetEnv() + portStr := env.TestMongodb70Port + if portStr == "" { + portStr = "27070" + } + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err)) + } + + return &mongodb.MongodbDatabase{ + Version: tools.MongodbVersion7, + Host: "localhost", + Port: port, + Username: "root", + Password: "rootpassword", + Database: "testdb", + AuthDatabase: "admin", + IsHttps: false, + CpuCount: 1, + } +} + func CreateTestDatabase( workspaceID uuid.UUID, storage *storages.Storage, @@ -18,16 +89,7 @@ func CreateTestDatabase( WorkspaceID: &workspaceID, Name: "test " + uuid.New().String(), Type: DatabaseTypePostgres, - - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - CpuCount: 1, - }, - + Postgresql: GetTestPostgresConfig(), Notifiers: []notifiers.Notifier{ *notifier, }, diff --git a/backend/internal/features/healthcheck/attempt/controller_test.go b/backend/internal/features/healthcheck/attempt/controller_test.go index 393bb0c..ee78988 100644 --- a/backend/internal/features/healthcheck/attempt/controller_test.go +++ b/backend/internal/features/healthcheck/attempt/controller_test.go @@ -12,13 +12,11 @@ import ( "github.com/stretchr/testify/assert" "databasus-backend/internal/features/databases" - "databasus-backend/internal/features/databases/databases/postgresql" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_testing "databasus-backend/internal/features/workspaces/testing" test_utils "databasus-backend/internal/util/testing" - "databasus-backend/internal/util/tools" ) func createTestRouter() *gin.Engine { @@ -205,20 +203,11 @@ func createTestDatabaseViaAPI( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" request := databases.Database{ WorkspaceID: &workspaceID, Name: name, Type: databases.DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: databases.GetTestPostgresConfig(), } w := workspaces_testing.MakeAPIRequest( diff --git a/backend/internal/features/healthcheck/config/controller_test.go b/backend/internal/features/healthcheck/config/controller_test.go index b9be184..74593de 100644 --- a/backend/internal/features/healthcheck/config/controller_test.go +++ b/backend/internal/features/healthcheck/config/controller_test.go @@ -10,13 +10,11 @@ import ( "github.com/stretchr/testify/assert" "databasus-backend/internal/features/databases" - "databasus-backend/internal/features/databases/databases/postgresql" users_enums "databasus-backend/internal/features/users/enums" users_testing "databasus-backend/internal/features/users/testing" workspaces_controllers "databasus-backend/internal/features/workspaces/controllers" workspaces_testing "databasus-backend/internal/features/workspaces/testing" test_utils "databasus-backend/internal/util/testing" - "databasus-backend/internal/util/tools" ) func createTestRouter() *gin.Engine { @@ -293,20 +291,11 @@ func createTestDatabaseViaAPI( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" request := databases.Database{ WorkspaceID: &workspaceID, Name: name, Type: databases.DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: databases.GetTestPostgresConfig(), } w := workspaces_testing.MakeAPIRequest( diff --git a/backend/internal/features/restores/controller_test.go b/backend/internal/features/restores/controller_test.go index 2a55f6f..5df0743 100644 --- a/backend/internal/features/restores/controller_test.go +++ b/backend/internal/features/restores/controller_test.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/http" + "strconv" "strings" "testing" "time" @@ -15,6 +16,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "databasus-backend/internal/config" audit_logs "databasus-backend/internal/features/audit_logs" "databasus-backend/internal/features/backups/backups" backups_config "databasus-backend/internal/features/backups/config" @@ -390,20 +392,11 @@ func createTestDatabase( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" request := databases.Database{ WorkspaceID: &workspaceID, Name: name, Type: databases.DatabaseTypePostgres, - Postgresql: &postgresql.PostgresqlDatabase{ - Version: tools.PostgresqlVersion16, - Host: "localhost", - Port: 5432, - Username: "postgres", - Password: "postgres", - Database: &testDbName, - CpuCount: 1, - }, + Postgresql: databases.GetTestPostgresConfig(), } w := workspaces_testing.MakeAPIRequest( @@ -434,7 +427,18 @@ func createTestMySQLDatabase( token string, router *gin.Engine, ) *databases.Database { - testDbName := "test_db" + env := config.GetEnv() + portStr := env.TestMysql80Port + if portStr == "" { + portStr = "33080" + } + + port, err := strconv.Atoi(portStr) + if err != nil { + panic(fmt.Sprintf("Failed to parse TEST_MYSQL_80_PORT: %v", err)) + } + + testDbName := "testdb" request := databases.Database{ WorkspaceID: &workspaceID, Name: name, @@ -442,9 +446,9 @@ func createTestMySQLDatabase( Mysql: &mysql.MysqlDatabase{ Version: tools.MysqlVersion80, Host: "localhost", - Port: 3306, - Username: "root", - Password: "password", + Port: port, + Username: "testuser", + Password: "testpassword", Database: &testDbName, }, } diff --git a/backend/internal/features/restores/service.go b/backend/internal/features/restores/service.go index 19f1996..917c4a8 100644 --- a/backend/internal/features/restores/service.go +++ b/backend/internal/features/restores/service.go @@ -229,8 +229,8 @@ func (s *RestoreService) RestoreBackup( Mongodb: requestDTO.MongodbDatabase, } - if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil { - return fmt.Errorf("failed to auto-detect database version: %w", err) + if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil { + return fmt.Errorf("failed to auto-detect database data: %w", err) } isExcludeExtensions := false diff --git a/backend/internal/features/tests/mariadb_backup_restore_test.go b/backend/internal/features/tests/mariadb_backup_restore_test.go index 3a29741..f2adab5 100644 --- a/backend/internal/features/tests/mariadb_backup_restore_test.go +++ b/backend/internal/features/tests/mariadb_backup_restore_test.go @@ -149,101 +149,6 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin } } -func Test_BackupAndRestoreMariadb_WithExcludeEvents_RestoreIsSuccessful(t *testing.T) { - env := config.GetEnv() - container, err := connectToMariadbContainer(tools.MariadbVersion120, env.TestMariadb120Port) - if err != nil { - t.Skipf("Skipping MariaDB 12.0 IsExcludeEvents test: %v", err) - return - } - defer func() { - if container.DB != nil { - container.DB.Close() - } - }() - - setupMariadbTestData(t, container.DB) - - router := createTestRouter() - user := users_testing.CreateTestUser(users_enums.UserRoleMember) - workspace := workspaces_testing.CreateTestWorkspace( - "MariaDB ExcludeEvents Test Workspace", - user, - router, - ) - - storage := storages.CreateTestStorage(workspace.ID) - - database := createMariadbDatabaseWithExcludeEventsViaAPI( - t, router, "MariaDB ExcludeEvents Test Database", workspace.ID, - container.Host, container.Port, - container.Username, container.Password, container.Database, - container.Version, - true, - user.Token, - ) - - enableBackupsViaAPI( - t, router, database.ID, storage.ID, - backups_config.BackupEncryptionNone, user.Token, - ) - - createBackupViaAPI(t, router, database.ID, user.Token) - - backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute) - assert.Equal(t, backups.BackupStatusCompleted, backup.Status) - - newDBName := "restoreddb_mariadb_excludeevents" - _, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName)) - assert.NoError(t, err) - - _, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName)) - assert.NoError(t, err) - - newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true", - container.Username, container.Password, container.Host, container.Port, newDBName) - newDB, err := sqlx.Connect("mysql", newDSN) - assert.NoError(t, err) - defer newDB.Close() - - createMariadbRestoreViaAPI( - t, router, backup.ID, - container.Host, container.Port, - container.Username, container.Password, newDBName, - container.Version, - user.Token, - ) - - restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute) - assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status) - - var tableExists int - err = newDB.Get( - &tableExists, - "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'", - newDBName, - ) - assert.NoError(t, err) - assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database") - - verifyMariadbDataIntegrity(t, container.DB, newDB) - - err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String())) - if err != nil { - t.Logf("Warning: Failed to delete backup file: %v", err) - } - - test_utils.MakeDeleteRequest( - t, - router, - "/api/v1/databases/"+database.ID.String(), - "Bearer "+user.Token, - http.StatusNoContent, - ) - storages.RemoveTestStorage(storage.ID) - workspaces_testing.RemoveTestWorkspace(workspace, router) -} - func testMariadbBackupRestoreForVersion( t *testing.T, mariadbVersion tools.MariadbVersion, @@ -554,40 +459,18 @@ func createMariadbDatabaseViaAPI( database string, version tools.MariadbVersion, token string, -) *databases.Database { - return createMariadbDatabaseWithExcludeEventsViaAPI( - t, router, name, workspaceID, - host, port, username, password, database, - version, false, token, - ) -} - -func createMariadbDatabaseWithExcludeEventsViaAPI( - t *testing.T, - router *gin.Engine, - name string, - workspaceID uuid.UUID, - host string, - port int, - username string, - password string, - database string, - version tools.MariadbVersion, - isExcludeEvents bool, - token string, ) *databases.Database { request := databases.Database{ Name: name, WorkspaceID: &workspaceID, Type: databases.DatabaseTypeMariadb, Mariadb: &mariadbtypes.MariadbDatabase{ - Host: host, - Port: port, - Username: username, - Password: password, - Database: &database, - Version: version, - IsExcludeEvents: isExcludeEvents, + Host: host, + Port: port, + Username: username, + Password: password, + Database: &database, + Version: version, }, } diff --git a/backend/migrations/20260105165527_add_privileges_to_mysql_mariadb.sql b/backend/migrations/20260105165527_add_privileges_to_mysql_mariadb.sql new file mode 100644 index 0000000..8944ef0 --- /dev/null +++ b/backend/migrations/20260105165527_add_privileges_to_mysql_mariadb.sql @@ -0,0 +1,23 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE mysql_databases + ADD COLUMN privileges TEXT NOT NULL DEFAULT ''; + +ALTER TABLE mariadb_databases + ADD COLUMN privileges TEXT NOT NULL DEFAULT ''; + +ALTER TABLE mariadb_databases + DROP COLUMN is_exclude_events; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE mariadb_databases + ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE; + +ALTER TABLE mariadb_databases + DROP COLUMN privileges; + +ALTER TABLE mysql_databases + DROP COLUMN privileges; +-- +goose StatementEnd diff --git a/backend/migrations/20260105170956_add_cascade_delete_to_workspace_fks.sql b/backend/migrations/20260105170956_add_cascade_delete_to_workspace_fks.sql new file mode 100644 index 0000000..1ef7aec --- /dev/null +++ b/backend/migrations/20260105170956_add_cascade_delete_to_workspace_fks.sql @@ -0,0 +1,43 @@ +-- +goose Up +-- +goose StatementBegin + +ALTER TABLE notifiers + DROP CONSTRAINT fk_notifiers_workspace_id; + +ALTER TABLE notifiers + ADD CONSTRAINT fk_notifiers_workspace_id + FOREIGN KEY (workspace_id) + REFERENCES workspaces (id) + ON DELETE CASCADE; + +ALTER TABLE storages + DROP CONSTRAINT fk_storages_workspace_id; + +ALTER TABLE storages + ADD CONSTRAINT fk_storages_workspace_id + FOREIGN KEY (workspace_id) + REFERENCES workspaces (id) + ON DELETE CASCADE; + +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin + +ALTER TABLE notifiers + DROP CONSTRAINT fk_notifiers_workspace_id; + +ALTER TABLE notifiers + ADD CONSTRAINT fk_notifiers_workspace_id + FOREIGN KEY (workspace_id) + REFERENCES workspaces (id); + +ALTER TABLE storages + DROP CONSTRAINT fk_storages_workspace_id; + +ALTER TABLE storages + ADD CONSTRAINT fk_storages_workspace_id + FOREIGN KEY (workspace_id) + REFERENCES workspaces (id); + +-- +goose StatementEnd diff --git a/frontend/src/features/databases/ui/CreateDatabaseComponent.tsx b/frontend/src/features/databases/ui/CreateDatabaseComponent.tsx index b6835aa..2be4005 100644 --- a/frontend/src/features/databases/ui/CreateDatabaseComponent.tsx +++ b/frontend/src/features/databases/ui/CreateDatabaseComponent.tsx @@ -163,6 +163,10 @@ export const CreateDatabaseComponent = ({ workspaceId, onCreated, onClose }: Pro } if (step === 'notifiers') { + if (isCreating) { + return