mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -679,7 +681,13 @@ func createTestDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
@@ -687,9 +695,9 @@ func createTestDatabase(
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
|
||||
@@ -107,12 +107,14 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--user=" + mdb.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if !mdb.IsExcludeEvents {
|
||||
if mdb.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if mdb.HasPrivilege("EVENT") {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
|
||||
@@ -105,13 +105,18 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--user=" + my.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if my.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if my.HasPrivilege("EVENT") {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
|
||||
|
||||
if my.IsHttps {
|
||||
|
||||
@@ -2,13 +2,16 @@ package backups_config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/intervals"
|
||||
@@ -1434,7 +1437,13 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
@@ -1442,9 +1451,9 @@ func createTestDatabaseViaAPI(
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
@@ -1459,7 +1468,9 @@ func createTestDatabaseViaAPI(
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
|
||||
@@ -392,13 +392,13 @@ func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
|
||||
isReadOnly, privileges, err := c.databaseService.IsUserReadOnly(user, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
|
||||
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly, Privileges: privileges})
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -32,6 +34,71 @@ func createTestRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -88,20 +155,11 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
testDbName := "test_db"
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
var response Database
|
||||
@@ -132,20 +190,11 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
testDbName := "test_db"
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: getTestPostgresConfig(),
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
@@ -737,7 +786,13 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
@@ -745,9 +800,9 @@ func createTestDatabaseViaAPI(
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
@@ -780,21 +835,14 @@ func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
testDbName := "test_db"
|
||||
plainPassword := "my-super-secret-password-123"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
plainPassword := "testpassword"
|
||||
pgConfig.Password = plainPassword
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: plainPassword,
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
|
||||
var createdDatabase Database
|
||||
@@ -854,38 +902,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "PostgreSQL Database",
|
||||
databaseType: DatabaseTypePostgres,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
testDbName := "test_db"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test PostgreSQL Database",
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "original-password-secret",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
testDbName := "updated_test_db"
|
||||
pgConfig := getTestPostgresConfig()
|
||||
pgConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated PostgreSQL Database",
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion17,
|
||||
Host: "updated-host",
|
||||
Port: 5433,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: pgConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -895,7 +928,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "testpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Postgresql.Password)
|
||||
@@ -905,36 +938,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "MariaDB Database",
|
||||
databaseType: DatabaseTypeMariadb,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
testDbName := "test_db"
|
||||
mariaConfig := getTestMariadbConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "original-password-secret",
|
||||
Database: &testDbName,
|
||||
},
|
||||
Mariadb: mariaConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
testDbName := "updated_test_db"
|
||||
mariaConfig := getTestMariadbConfig()
|
||||
mariaConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion114,
|
||||
Host: "updated-host",
|
||||
Port: 3307,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: &testDbName,
|
||||
},
|
||||
Mariadb: mariaConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -944,7 +964,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "testpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Mariadb.Password)
|
||||
@@ -954,40 +974,23 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
name: "MongoDB Database",
|
||||
databaseType: DatabaseTypeMongodb,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
mongoConfig := getTestMongodbConfig()
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test MongoDB Database",
|
||||
Type: DatabaseTypeMongodb,
|
||||
Mongodb: &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: 27017,
|
||||
Username: "root",
|
||||
Password: "original-password-secret",
|
||||
Database: "test_db",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Mongodb: mongoConfig,
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
mongoConfig := getTestMongodbConfig()
|
||||
mongoConfig.Password = ""
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated MongoDB Database",
|
||||
Type: DatabaseTypeMongodb,
|
||||
Mongodb: &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion8,
|
||||
Host: "updated-host",
|
||||
Port: 27018,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: "updated_test_db",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Mongodb: mongoConfig,
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
@@ -997,7 +1000,7 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Mongodb.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
assert.Equal(t, "rootpassword", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Mongodb.Password)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -23,15 +24,13 @@ type MariadbDatabase struct {
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
|
||||
// advanced
|
||||
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"column:is_exclude_events;type:boolean;default:false"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
@@ -97,6 +96,16 @@ func (m *MariadbDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
if err := checkBackupPermissions(m.Privileges); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -114,7 +123,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsExcludeEvents = incoming.IsExcludeEvents
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
@@ -135,15 +144,48 @@ func (m *MariadbDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MariadbDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMariadbVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersion(
|
||||
@@ -179,8 +221,8 @@ func (m *MariadbDatabase) PopulateVersion(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -189,17 +231,17 @@ func (m *MariadbDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
@@ -209,33 +251,44 @@ func (m *MariadbDatabase) IsUserReadOnly(
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
|
||||
"ALTER ROUTINE", "CREATE USER",
|
||||
"CREATE TABLESPACE", "DELETE HISTORY", "REFERENCES",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
privileges := make([]string, 0, len(detectedPrivileges))
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
@@ -330,10 +383,23 @@ func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) HasPrivilege(priv string) bool {
|
||||
return HasPrivilege(m.Privileges, priv)
|
||||
}
|
||||
|
||||
func HasPrivilege(privileges, priv string) bool {
|
||||
for _, p := range strings.Split(privileges, ",") {
|
||||
if strings.TrimSpace(p) == priv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
tlsConfig = "skip-verify"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -424,6 +490,99 @@ func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
|
||||
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
backupPrivileges := []string{
|
||||
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
if hasAllPrivileges {
|
||||
for _, priv := range backupPrivileges {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
hasProcess = true
|
||||
}
|
||||
|
||||
privileges := make([]string, 0, len(detectedPrivileges)+1)
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
if hasProcess {
|
||||
privileges = append(privileges, "PROCESS")
|
||||
}
|
||||
|
||||
sort.Strings(privileges)
|
||||
return strings.Join(privileges, ","), nil
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mariadb-dump backup.
|
||||
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
if !HasPrivilege(privileges, priv) {
|
||||
missingPrivileges = append(missingPrivileges, priv)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -18,6 +18,171 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE permission_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, limitedUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE backup_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.DB, backupUsername)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mariadbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -49,13 +214,56 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Root user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MariadbDatabase{
|
||||
Version: mariadbModel.Version,
|
||||
Host: mariadbModel.Host,
|
||||
Port: mariadbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mariadbModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -127,9 +335,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
@@ -382,6 +596,5 @@ func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
|
||||
}
|
||||
|
||||
func dropUserSafe(db *sqlx.DB, username string) {
|
||||
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so we ignore errors
|
||||
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -96,6 +97,10 @@ func (m *MongodbDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
if err := checkBackupPermissions(ctx, client, m.Username, m.Database, m.AuthDatabase); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -135,14 +140,11 @@ func (m *MongodbDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MongodbDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MongodbDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
@@ -186,10 +188,10 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
uri := m.buildConnectionURI(password)
|
||||
@@ -197,7 +199,7 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
clientOptions := options.Client().ApplyURI(uri)
|
||||
client, err := mongo.Connect(ctx, clientOptions)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if disconnectErr := client.Disconnect(ctx); disconnectErr != nil {
|
||||
@@ -219,44 +221,153 @@ func (m *MongodbDatabase) IsUserReadOnly(
|
||||
}},
|
||||
}).Decode(&result)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get user info: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
writeRoles := []string{
|
||||
"readWrite", "readWriteAnyDatabase", "dbAdmin", "dbAdminAnyDatabase",
|
||||
"userAdmin", "userAdminAnyDatabase", "clusterAdmin", "root",
|
||||
"dbOwner", "backup", "restore",
|
||||
writeRoles := map[string]bool{
|
||||
"readWrite": true,
|
||||
"readWriteAnyDatabase": true,
|
||||
"dbAdmin": true,
|
||||
"dbAdminAnyDatabase": true,
|
||||
"userAdmin": true,
|
||||
"userAdminAnyDatabase": true,
|
||||
"clusterAdmin": true,
|
||||
"clusterManager": true,
|
||||
"hostManager": true,
|
||||
"root": true,
|
||||
"dbOwner": true,
|
||||
"restore": true,
|
||||
"__system": true,
|
||||
}
|
||||
|
||||
// Roles that are read-only for our backup purposes
|
||||
// The "backup" role has insert/update on mms.backup collection but is needed for mongodump
|
||||
readOnlyRoles := map[string]bool{
|
||||
"read": true,
|
||||
"backup": true,
|
||||
}
|
||||
|
||||
writeActions := map[string]bool{
|
||||
"insert": true,
|
||||
"update": true,
|
||||
"remove": true,
|
||||
"createCollection": true,
|
||||
"dropCollection": true,
|
||||
"createIndex": true,
|
||||
"dropIndex": true,
|
||||
"convertToCapped": true,
|
||||
"dropDatabase": true,
|
||||
"renameCollection": true,
|
||||
"createUser": true,
|
||||
"dropUser": true,
|
||||
"updateUser": true,
|
||||
"grantRole": true,
|
||||
"revokeRole": true,
|
||||
"dropRole": true,
|
||||
"createRole": true,
|
||||
"updateRole": true,
|
||||
"enableSharding": true,
|
||||
"shardCollection": true,
|
||||
"addShard": true,
|
||||
"removeShard": true,
|
||||
"shutdown": true,
|
||||
"replSetReconfig": true,
|
||||
"replSetStateChange": true,
|
||||
}
|
||||
|
||||
var detectedRoles []string
|
||||
|
||||
users, ok := result["users"].(bson.A)
|
||||
if !ok || len(users) == 0 {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
user, ok := users[0].(bson.M)
|
||||
if !ok {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
roles, ok := user["roles"].(bson.A)
|
||||
if !ok {
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Collect all role names and check for write roles
|
||||
for _, roleDoc := range roles {
|
||||
role, ok := roleDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
roleName, _ := role["role"].(string)
|
||||
for _, writeRole := range writeRoles {
|
||||
if roleName == writeRole {
|
||||
return false, nil
|
||||
if roleName != "" {
|
||||
detectedRoles = append(detectedRoles, roleName)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if any detected role is a write role
|
||||
for _, roleName := range detectedRoles {
|
||||
if writeRoles[roleName] {
|
||||
return false, detectedRoles, nil
|
||||
}
|
||||
}
|
||||
|
||||
// If all roles are known read-only roles (read, backup), skip inherited privilege check
|
||||
allRolesReadOnly := true
|
||||
for _, roleName := range detectedRoles {
|
||||
if !readOnlyRoles[roleName] {
|
||||
allRolesReadOnly = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if allRolesReadOnly && len(detectedRoles) > 0 {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Check inherited privileges for custom roles
|
||||
var privResult bson.M
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "usersInfo", Value: bson.D{
|
||||
{Key: "user", Value: m.Username},
|
||||
{Key: "db", Value: authDB},
|
||||
}},
|
||||
{Key: "showPrivileges", Value: true},
|
||||
}).Decode(&privResult)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to get user privileges: %w", err)
|
||||
}
|
||||
|
||||
privUsers, ok := privResult["users"].(bson.A)
|
||||
if !ok || len(privUsers) == 0 {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
privUser, ok := privUsers[0].(bson.M)
|
||||
if !ok {
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
// Check inheritedPrivileges for write actions
|
||||
inheritedPrivileges, ok := privUser["inheritedPrivileges"].(bson.A)
|
||||
if ok {
|
||||
for _, privDoc := range inheritedPrivileges {
|
||||
priv, ok := privDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
actions, ok := priv["actions"].(bson.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, action := range actions {
|
||||
actionStr, ok := action.(string)
|
||||
if ok && writeActions[actionStr] {
|
||||
return false, detectedRoles, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return true, detectedRoles, nil
|
||||
}
|
||||
|
||||
func (m *MongodbDatabase) CreateReadOnlyUser(
|
||||
@@ -333,20 +444,20 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
tlsOption := "false"
|
||||
tlsParams := ""
|
||||
if m.IsHttps {
|
||||
tlsOption = "true"
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&tls=%s&connectTimeoutMS=15000",
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsOption,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -357,19 +468,19 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
tlsOption := "false"
|
||||
tlsParams := ""
|
||||
if m.IsHttps {
|
||||
tlsOption = "true"
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&tls=%s&connectTimeoutMS=15000",
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
authDB,
|
||||
tlsOption,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -414,6 +525,128 @@ func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.Mong
|
||||
}
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mongodump backup.
|
||||
// Required: 'read' role on target database OR 'backup' role on admin OR 'readAnyDatabase' role.
|
||||
func checkBackupPermissions(
|
||||
ctx context.Context,
|
||||
client *mongo.Client,
|
||||
username, database, authDatabase string,
|
||||
) error {
|
||||
authDB := authDatabase
|
||||
if authDB == "" {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
adminDB := client.Database(authDB)
|
||||
var result bson.M
|
||||
err := adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "usersInfo", Value: bson.D{
|
||||
{Key: "user", Value: username},
|
||||
{Key: "db", Value: authDB},
|
||||
}},
|
||||
{Key: "showPrivileges", Value: true},
|
||||
}).Decode(&result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
|
||||
users, ok := result["users"].(bson.A)
|
||||
if !ok || len(users) == 0 {
|
||||
return errors.New("insufficient permissions for backup. User not found")
|
||||
}
|
||||
|
||||
user, ok := users[0].(bson.M)
|
||||
if !ok {
|
||||
return errors.New("insufficient permissions for backup. Could not parse user info")
|
||||
}
|
||||
|
||||
// Check roles for backup permissions
|
||||
roles, ok := user["roles"].(bson.A)
|
||||
if !ok {
|
||||
return errors.New("insufficient permissions for backup. No roles assigned")
|
||||
}
|
||||
|
||||
backupRoles := map[string]bool{
|
||||
"backup": true,
|
||||
"root": true,
|
||||
"readAnyDatabase": true,
|
||||
"dbOwner": true,
|
||||
"__system": true,
|
||||
"clusterAdmin": true,
|
||||
"readWriteAnyDatabase": true,
|
||||
}
|
||||
|
||||
var userRoles []string
|
||||
hasBackupRole := false
|
||||
hasReadOnTargetDB := false
|
||||
|
||||
for _, roleDoc := range roles {
|
||||
role, ok := roleDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
roleName, _ := role["role"].(string)
|
||||
roleDB, _ := role["db"].(string)
|
||||
|
||||
if roleName != "" {
|
||||
userRoles = append(userRoles, roleName)
|
||||
}
|
||||
|
||||
if backupRoles[roleName] {
|
||||
hasBackupRole = true
|
||||
}
|
||||
|
||||
if roleName == "read" && (roleDB == database || roleDB == "") {
|
||||
hasReadOnTargetDB = true
|
||||
}
|
||||
if roleName == "readWrite" && (roleDB == database || roleDB == "") {
|
||||
hasReadOnTargetDB = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasBackupRole || hasReadOnTargetDB {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check inherited privileges for 'find' action on target database
|
||||
inheritedPrivileges, ok := user["inheritedPrivileges"].(bson.A)
|
||||
if ok {
|
||||
for _, privDoc := range inheritedPrivileges {
|
||||
priv, ok := privDoc.(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
resource, ok := priv["resource"].(bson.M)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
resourceDB, _ := resource["db"].(string)
|
||||
resourceCluster, _ := resource["cluster"].(bool)
|
||||
|
||||
isTargetDB := resourceDB == database || resourceDB == "" || resourceCluster
|
||||
|
||||
actions, ok := priv["actions"].(bson.A)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, action := range actions {
|
||||
actionStr, ok := action.(string)
|
||||
if ok && actionStr == "find" && isTargetDB {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Current roles: %s. Required: 'read' role on database '%s' OR 'backup' role on admin OR 'readAnyDatabase' role",
|
||||
strings.Join(userRoles, ", "),
|
||||
database,
|
||||
)
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -20,6 +20,138 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MongodbVersion
|
||||
port string
|
||||
}{
|
||||
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
|
||||
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
|
||||
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
|
||||
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
|
||||
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
|
||||
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
|
||||
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("permission_test").Drop(ctx)
|
||||
_, err := db.Collection("permission_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
adminDB := container.Client.Database(container.AuthDatabase)
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "createUser", Value: limitedUsername},
|
||||
{Key: "pwd", Value: limitedPassword},
|
||||
{Key: "roles", Value: bson.A{}},
|
||||
}).Err()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
|
||||
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mongodbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MongodbVersion
|
||||
port string
|
||||
}{
|
||||
{"MongoDB 4.0", tools.MongodbVersion4, env.TestMongodb40Port},
|
||||
{"MongoDB 4.2", tools.MongodbVersion4, env.TestMongodb42Port},
|
||||
{"MongoDB 4.4", tools.MongodbVersion4, env.TestMongodb44Port},
|
||||
{"MongoDB 5.0", tools.MongodbVersion5, env.TestMongodb50Port},
|
||||
{"MongoDB 6.0", tools.MongodbVersion6, env.TestMongodb60Port},
|
||||
{"MongoDB 7.0", tools.MongodbVersion7, env.TestMongodb70Port},
|
||||
{"MongoDB 8.2", tools.MongodbVersion8, env.TestMongodb82Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMongodbContainer(t, tc.port, tc.version)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("backup_test").Drop(ctx)
|
||||
_, err := db.Collection("backup_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
adminDB := container.Client.Database(container.AuthDatabase)
|
||||
err = adminDB.RunCommand(ctx, bson.D{
|
||||
{Key: "createUser", Value: backupUsername},
|
||||
{Key: "pwd", Value: backupPassword},
|
||||
{Key: "roles", Value: bson.A{
|
||||
bson.D{
|
||||
{Key: "role", Value: "read"},
|
||||
{Key: "db", Value: container.Database},
|
||||
},
|
||||
}},
|
||||
}).Err()
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
|
||||
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mongodbModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -47,13 +179,52 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, roles, err := mongodbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, roles, "Root user should have roles")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMongodbContainer(t, env.TestMongodb70Port, tools.MongodbVersion7)
|
||||
defer container.Client.Disconnect(context.Background())
|
||||
|
||||
ctx := context.Background()
|
||||
db := container.Client.Database(container.Database)
|
||||
|
||||
_ = db.Collection("readonly_check_test").Drop(ctx)
|
||||
_, err := db.Collection("readonly_check_test").InsertOne(ctx, bson.M{"data": "test1"})
|
||||
assert.NoError(t, err)
|
||||
|
||||
mongodbModel := createMongodbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
username, password, err := mongodbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MongodbDatabase{
|
||||
Version: mongodbModel.Version,
|
||||
Host: mongodbModel.Host,
|
||||
Port: mongodbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mongodbModel.Database,
|
||||
AuthDatabase: mongodbModel.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
isReadOnly, roles, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.NotEmpty(t, roles, "Read-only user should have roles (read, backup)")
|
||||
|
||||
dropUserSafe(container.Client, username, container.AuthDatabase)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -272,6 +443,7 @@ func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"databasus-backend/internal/util/encryption"
|
||||
@@ -22,12 +24,13 @@ type MysqlDatabase struct {
|
||||
|
||||
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) TableName() string {
|
||||
@@ -93,6 +96,16 @@ func (m *MysqlDatabase) TestConnection(
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
if err := checkBackupPermissions(m.Privileges); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -110,6 +123,7 @@ func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
@@ -130,15 +144,48 @@ func (m *MysqlDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersionIfEmpty(
|
||||
func (m *MysqlDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMysqlVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
privileges, err := detectPrivileges(ctx, db, *m.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Privileges = privileges
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersion(
|
||||
@@ -174,8 +221,8 @@ func (m *MysqlDatabase) PopulateVersion(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -184,17 +231,17 @@ func (m *MysqlDatabase) IsUserReadOnly(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
@@ -204,33 +251,45 @@ func (m *MysqlDatabase) IsUserReadOnly(
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
"EXECUTE", "FILE", "RELOAD", "SHUTDOWN", "CREATE ROUTINE",
|
||||
"ALTER ROUTINE", "CREATE USER",
|
||||
"CREATE TABLESPACE", "REFERENCES",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
privileges := make([]string, 0, len(detectedPrivileges))
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
@@ -325,10 +384,23 @@ func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) HasPrivilege(priv string) bool {
|
||||
return HasPrivilege(m.Privileges, priv)
|
||||
}
|
||||
|
||||
func HasPrivilege(privileges, priv string) bool {
|
||||
for p := range strings.SplitSeq(privileges, ",") {
|
||||
if strings.TrimSpace(p) == priv {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
tlsConfig = "skip-verify"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
@@ -388,6 +460,99 @@ func mapMysql8xVersion(minor string) tools.MysqlVersion {
|
||||
}
|
||||
}
|
||||
|
||||
// detectPrivileges detects backup-related privileges and returns them as comma-separated string
|
||||
func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string, error) {
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
backupPrivileges := []string{
|
||||
"SELECT", "SHOW VIEW", "LOCK TABLES", "TRIGGER", "EVENT",
|
||||
}
|
||||
|
||||
detectedPrivileges := make(map[string]bool)
|
||||
hasProcess := false
|
||||
hasAllPrivileges := false
|
||||
|
||||
escapedDB := strings.ReplaceAll(database, "_", "\\_")
|
||||
dbPattern := regexp.MustCompile(
|
||||
fmt.Sprintf("(?i)ON\\s+[`'\"]?(%s|\\*)[`'\"]?\\.\\*", regexp.QuoteMeta(escapedDB)),
|
||||
)
|
||||
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\.\*`)
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return "", fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
if regexp.MustCompile(`(?i)\bALL\s+PRIVILEGES\b`).MatchString(grant) {
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
hasAllPrivileges = true
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) || dbPattern.MatchString(grant) {
|
||||
for _, priv := range backupPrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if globalPattern.MatchString(grant) &&
|
||||
regexp.MustCompile(`(?i)\bPROCESS\b`).MatchString(grant) {
|
||||
hasProcess = true
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
if hasAllPrivileges {
|
||||
for _, priv := range backupPrivileges {
|
||||
detectedPrivileges[priv] = true
|
||||
}
|
||||
hasProcess = true
|
||||
}
|
||||
|
||||
privileges := make([]string, 0, len(detectedPrivileges)+1)
|
||||
for priv := range detectedPrivileges {
|
||||
privileges = append(privileges, priv)
|
||||
}
|
||||
if hasProcess {
|
||||
privileges = append(privileges, "PROCESS")
|
||||
}
|
||||
|
||||
sort.Strings(privileges)
|
||||
return strings.Join(privileges, ","), nil
|
||||
}
|
||||
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for mysqldump backup.
|
||||
// Required: SELECT, SHOW VIEW, PROCESS. Optional: LOCK TABLES, TRIGGER, EVENT.
|
||||
func checkBackupPermissions(privileges string) error {
|
||||
requiredPrivileges := []string{"SELECT", "SHOW VIEW", "PROCESS"}
|
||||
|
||||
var missingPrivileges []string
|
||||
for _, priv := range requiredPrivileges {
|
||||
if !HasPrivilege(privileges, priv) {
|
||||
missingPrivileges = append(missingPrivileges, priv)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: SELECT, SHOW VIEW, PROCESS",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
|
||||
@@ -18,6 +18,165 @@ import (
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS permission_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE permission_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO permission_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", limitedUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS backup_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE backup_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO backup_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec("FLUSH PRIVILEGES")
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(
|
||||
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", backupUsername),
|
||||
)
|
||||
}()
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = mysqlModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -42,13 +201,57 @@ func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Root user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_check_test`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`CREATE TABLE readonly_check_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`INSERT INTO readonly_check_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &MysqlDatabase{
|
||||
Version: mysqlModel.Version,
|
||||
Host: mysqlModel.Host,
|
||||
Port: mysqlModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mysqlModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
@@ -109,9 +312,15 @@ func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
@@ -137,16 +137,13 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
|
||||
// PopulateDbData detects and sets the PostgreSQL version.
|
||||
// This should be called before encrypting sensitive fields.
|
||||
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
|
||||
func (p *PostgresqlDatabase) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if p.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return p.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
@@ -192,29 +189,33 @@ func (p *PostgresqlDatabase) PopulateVersion(
|
||||
// IsUserReadOnly checks if the database user has read-only privileges.
|
||||
//
|
||||
// This method performs a comprehensive security check by examining:
|
||||
// - Role-level attributes (superuser, createrole, createdb)
|
||||
// - Role-level attributes (superuser, createrole, createdb, bypassrls, replication)
|
||||
// - Database-level privileges (CREATE, TEMP)
|
||||
// - Schema-level privileges (CREATE on any non-system schema)
|
||||
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
|
||||
// - Function-level privileges (EXECUTE on SECURITY DEFINER functions)
|
||||
//
|
||||
// A user is considered read-only only if they have ZERO write privileges
|
||||
// across all three levels. This ensures the database user follows the
|
||||
// across all levels. This ensures the database user follows the
|
||||
// principle of least privilege for backup operations.
|
||||
//
|
||||
// Returns: (isReadOnly, detectedPrivileges, error)
|
||||
func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
connStr := buildConnectionStringForDB(p, *p.Database, password)
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
@@ -222,22 +223,38 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
}
|
||||
}()
|
||||
|
||||
var privileges []string
|
||||
|
||||
// LEVEL 1: Check role-level attributes
|
||||
var isSuperuser, canCreateRole, canCreateDB bool
|
||||
var isSuperuser, canCreateRole, canCreateDB, canBypassRLS, canReplication bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT
|
||||
rolsuper,
|
||||
rolcreaterole,
|
||||
rolcreatedb
|
||||
rolcreatedb,
|
||||
rolbypassrls,
|
||||
rolreplication
|
||||
FROM pg_roles
|
||||
WHERE rolname = current_user
|
||||
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
|
||||
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB, &canBypassRLS, &canReplication)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check role attributes: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check role attributes: %w", err)
|
||||
}
|
||||
|
||||
if isSuperuser || canCreateRole || canCreateDB {
|
||||
return false, nil
|
||||
if isSuperuser {
|
||||
privileges = append(privileges, "SUPERUSER")
|
||||
}
|
||||
if canCreateRole {
|
||||
privileges = append(privileges, "CREATEROLE")
|
||||
}
|
||||
if canCreateDB {
|
||||
privileges = append(privileges, "CREATEDB")
|
||||
}
|
||||
if canBypassRLS {
|
||||
privileges = append(privileges, "BYPASSRLS")
|
||||
}
|
||||
if canReplication {
|
||||
privileges = append(privileges, "REPLICATION")
|
||||
}
|
||||
|
||||
// LEVEL 2: Check database-level privileges
|
||||
@@ -248,46 +265,34 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
|
||||
`).Scan(&canCreate, &canTemp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check database privileges: %w", err)
|
||||
}
|
||||
|
||||
if canCreate || canTemp {
|
||||
return false, nil
|
||||
if canCreate {
|
||||
privileges = append(privileges, "CREATE (database)")
|
||||
}
|
||||
if canTemp {
|
||||
privileges = append(privileges, "TEMP")
|
||||
}
|
||||
|
||||
// LEVEL 2.5: Check schema-level CREATE privileges
|
||||
schemaRows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT nspname
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
|
||||
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
`)
|
||||
var hasSchemaCreate bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
|
||||
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
)
|
||||
`).Scan(&hasSchemaCreate)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check schema privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("failed to check schema privileges: %w", err)
|
||||
}
|
||||
defer schemaRows.Close()
|
||||
|
||||
// If user has CREATE privilege on any schema, they're not read-only
|
||||
if schemaRows.Next() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := schemaRows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating schema privileges: %w", err)
|
||||
if hasSchemaCreate {
|
||||
privileges = append(privileges, "CREATE (schema)")
|
||||
}
|
||||
|
||||
// LEVEL 3: Check table-level write permissions
|
||||
rows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = current_user
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check table privileges: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
writePrivileges := map[string]bool{
|
||||
"INSERT": true,
|
||||
"UPDATE": true,
|
||||
@@ -297,22 +302,56 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
"TRIGGER": true,
|
||||
}
|
||||
|
||||
var tablePrivileges []string
|
||||
rows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = current_user
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to check table privileges: %w", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var privilege string
|
||||
if err := rows.Scan(&privilege); err != nil {
|
||||
return false, fmt.Errorf("failed to scan privilege: %w", err)
|
||||
}
|
||||
|
||||
if writePrivileges[privilege] {
|
||||
return false, nil
|
||||
rows.Close()
|
||||
return false, nil, fmt.Errorf("failed to scan privilege: %w", err)
|
||||
}
|
||||
tablePrivileges = append(tablePrivileges, privilege)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating privileges: %w", err)
|
||||
return false, nil, fmt.Errorf("error iterating privileges: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
for _, privilege := range tablePrivileges {
|
||||
if writePrivileges[privilege] {
|
||||
privileges = append(privileges, privilege)
|
||||
}
|
||||
}
|
||||
|
||||
// LEVEL 4: Check for EXECUTE privilege on functions that are SECURITY DEFINER
|
||||
var funcCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_proc p
|
||||
JOIN pg_namespace n ON p.pronamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND p.prosecdef = true
|
||||
AND has_function_privilege(current_user, p.oid, 'EXECUTE')
|
||||
`).Scan(&funcCount)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("failed to check function privileges: %w", err)
|
||||
}
|
||||
if funcCount > 0 {
|
||||
privileges = append(privileges, "EXECUTE (SECURITY DEFINER)")
|
||||
}
|
||||
|
||||
isReadOnly := len(privileges) == 0
|
||||
return isReadOnly, privileges, nil
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
|
||||
@@ -631,13 +670,9 @@ func testSingleDatabaseConnection(
|
||||
}
|
||||
postgresDb.Version = detectedVersion
|
||||
|
||||
// Test if we can perform basic operations (like pg_dump would need)
|
||||
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
|
||||
return fmt.Errorf(
|
||||
"basic operations test failed for database '%s': %w",
|
||||
*postgresDb.Database,
|
||||
err,
|
||||
)
|
||||
// Verify user has sufficient permissions for backup operations
|
||||
if err := checkBackupPermissions(ctx, conn, *postgresDb.Database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -670,18 +705,73 @@ func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.Postgresq
|
||||
}
|
||||
}
|
||||
|
||||
// testBasicOperations tests basic operations that backup tools need
|
||||
func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) error {
|
||||
var hasCreatePriv bool
|
||||
// checkBackupPermissions verifies the user has sufficient privileges for pg_dump backup.
|
||||
// Required privileges: CONNECT on database, USAGE on schemas, SELECT on tables.
|
||||
func checkBackupPermissions(ctx context.Context, conn *pgx.Conn, dbName string) error {
|
||||
var missingPrivileges []string
|
||||
|
||||
// Check CONNECT privilege on database
|
||||
var hasConnect bool
|
||||
err := conn.QueryRow(ctx, "SELECT has_database_privilege(current_user, current_database(), 'CONNECT')").
|
||||
Scan(&hasCreatePriv)
|
||||
Scan(&hasConnect)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check database privileges: %w", err)
|
||||
}
|
||||
if !hasConnect {
|
||||
missingPrivileges = append(missingPrivileges, "CONNECT on database")
|
||||
}
|
||||
|
||||
if !hasCreatePriv {
|
||||
return fmt.Errorf("user does not have CONNECT privilege on database '%s'", dbName)
|
||||
// Check USAGE privilege on at least one non-system schema
|
||||
var schemaCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'USAGE')
|
||||
AND n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname NOT LIKE 'pg_temp_%'
|
||||
AND n.nspname NOT LIKE 'pg_toast_temp_%'
|
||||
`).Scan(&schemaCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check schema privileges: %w", err)
|
||||
}
|
||||
if schemaCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "USAGE on at least one schema")
|
||||
}
|
||||
|
||||
// Check SELECT privilege on at least one table (if tables exist)
|
||||
// Use pg_tables from pg_catalog which shows all tables regardless of user privileges
|
||||
var tableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
`).Scan(&tableCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check table count: %w", err)
|
||||
}
|
||||
|
||||
if tableCount > 0 {
|
||||
// Check if user has SELECT on at least one of these tables
|
||||
var selectableTableCount int
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT COUNT(*)
|
||||
FROM pg_catalog.pg_tables t
|
||||
WHERE t.schemaname NOT IN ('pg_catalog', 'information_schema')
|
||||
AND has_table_privilege(current_user, quote_ident(t.schemaname) || '.' || quote_ident(t.tablename), 'SELECT')
|
||||
`).Scan(&selectableTableCount)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot check SELECT privileges: %w", err)
|
||||
}
|
||||
if selectableTableCount == 0 {
|
||||
missingPrivileges = append(missingPrivileges, "SELECT on tables")
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingPrivileges) > 0 {
|
||||
return fmt.Errorf(
|
||||
"insufficient permissions for backup. Missing: %s. Required: CONNECT on database, USAGE on schemas, SELECT on tables",
|
||||
strings.Join(missingPrivileges, ", "),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/tools"
|
||||
@@ -18,13 +21,23 @@ import (
|
||||
|
||||
func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToTestPostgresContainer(t, env.TestPostgres16Port)
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
passwordWithSpaces := "test password with spaces"
|
||||
usernameWithSpaces := fmt.Sprintf("testuser_spaces_%s", uuid.New().String()[:8])
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS password_test CASCADE;
|
||||
CREATE TABLE password_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO password_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
usernameWithSpaces,
|
||||
passwordWithSpaces,
|
||||
@@ -38,6 +51,18 @@ func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
usernameWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
|
||||
usernameWithSpaces,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, usernameWithSpaces))
|
||||
}()
|
||||
@@ -59,7 +84,628 @@ func Test_TestConnection_PasswordContainingSpaces_TestedSuccessfully(t *testing.
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type testPostgresContainer struct {
|
||||
func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS permission_test CASCADE;
|
||||
CREATE TABLE permission_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO permission_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
limitedUsername := fmt.Sprintf("limited_user_%s", uuid.New().String()[:8])
|
||||
limitedPassword := "limitedpassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
limitedUsername,
|
||||
limitedPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
limitedUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedUsername))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = pgModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.Error(t, err)
|
||||
if err != nil {
|
||||
assert.Contains(t, err.Error(), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS backup_test CASCADE;
|
||||
CREATE TABLE backup_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO backup_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backupUsername := fmt.Sprintf("backup_user_%s", uuid.New().String()[:8])
|
||||
backupPassword := "backuppassword123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
backupUsername,
|
||||
backupPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA public TO "%s"`,
|
||||
backupUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, backupUsername))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum(tc.version),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
|
||||
err = pgModel.TestConnection(logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, privileges, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Admin user should not be read-only")
|
||||
assert.NotEmpty(t, privileges, "Admin user should have privileges")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_IsUserReadOnly_ReadOnlyUser_ReturnsTrue(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS readonly_check_test CASCADE;
|
||||
CREATE TABLE readonly_check_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO readonly_check_test (data) VALUES ('test1');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Read-only user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS readonly_test CASCADE;
|
||||
DROP TABLE IF EXISTS hack_table CASCADE;
|
||||
DROP TABLE IF EXISTS future_table CASCADE;
|
||||
CREATE TABLE readonly_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
assert.Empty(t, privileges, "Read-only user should have no write privileges")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO future_table (data) VALUES ('future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS schema_a CASCADE;
|
||||
DROP SCHEMA IF EXISTS schema_b CASCADE;
|
||||
CREATE SCHEMA schema_a;
|
||||
CREATE SCHEMA schema_b;
|
||||
CREATE TABLE schema_a.table_a (id INT, data TEXT);
|
||||
CREATE TABLE schema_b.table_b (id INT, data TEXT);
|
||||
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
|
||||
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var dataA string
|
||||
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_a", dataA)
|
||||
|
||||
var dataB string
|
||||
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_b", dataB)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, container.Username, container.Password, dashDbName)
|
||||
dashDB, err := sqlx.Connect("postgres", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum("16"),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
t.Skip("Skipping Supabase test: missing environment variables")
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(env.TestSupabasePort)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
env.TestSupabaseUsername,
|
||||
env.TestSupabasePassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
|
||||
adminDB, err := sqlx.Connect("postgres", dsn)
|
||||
require.NoError(t, err)
|
||||
defer adminDB.Close()
|
||||
|
||||
tableName := fmt.Sprintf(
|
||||
"readonly_test_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = adminDB.Exec(fmt.Sprintf(`
|
||||
DROP TABLE IF EXISTS public.%s CASCADE;
|
||||
CREATE TABLE public.%s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
|
||||
`, tableName, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: env.TestSupabaseHost,
|
||||
Port: portInt,
|
||||
Username: env.TestSupabaseUsername,
|
||||
Password: env.TestSupabasePassword,
|
||||
Database: &env.TestSupabaseDatabase,
|
||||
IsHttps: true,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, connectionUsername)
|
||||
assert.NotEmpty(t, newPassword)
|
||||
assert.True(t, strings.HasPrefix(connectionUsername, "databasus-"))
|
||||
|
||||
baseUsername := connectionUsername
|
||||
if idx := strings.Index(connectionUsername, "."); idx != -1 {
|
||||
baseUsername = connectionUsername[:idx]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
|
||||
}()
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
connectionUsername,
|
||||
newPassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
@@ -68,7 +714,7 @@ type testPostgresContainer struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresContainer {
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
@@ -83,7 +729,11 @@ func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresCont
|
||||
db, err := sqlx.Connect("postgres", dsn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
return &testPostgresContainer{
|
||||
var versionStr string
|
||||
err = db.Get(&versionStr, "SELECT version()")
|
||||
assert.NoError(t, err)
|
||||
|
||||
return &PostgresContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
@@ -92,3 +742,42 @@ func connectToTestPostgresContainer(t *testing.T, port string) *testPostgresCont
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
|
||||
var versionStr string
|
||||
err := container.DB.Get(&versionStr, "SELECT version()")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
version := extractPostgresVersion(versionStr)
|
||||
|
||||
return &PostgresqlDatabase{
|
||||
Version: version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
|
||||
if strings.Contains(versionStr, "PostgreSQL 12") {
|
||||
return tools.GetPostgresqlVersionEnum("12")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 13") {
|
||||
return tools.GetPostgresqlVersionEnum("13")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 14") {
|
||||
return tools.GetPostgresqlVersionEnum("14")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 15") {
|
||||
return tools.GetPostgresqlVersionEnum("15")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 16") {
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 17") {
|
||||
return tools.GetPostgresqlVersionEnum("17")
|
||||
}
|
||||
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
}
|
||||
|
||||
@@ -1,508 +0,0 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Admin user should not be read-only")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS readonly_test CASCADE;
|
||||
DROP TABLE IF EXISTS hack_table CASCADE;
|
||||
DROP TABLE IF EXISTS future_table CASCADE;
|
||||
CREATE TABLE readonly_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO future_table (data) VALUES ('future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS schema_a CASCADE;
|
||||
DROP SCHEMA IF EXISTS schema_b CASCADE;
|
||||
CREATE SCHEMA schema_a;
|
||||
CREATE SCHEMA schema_b;
|
||||
CREATE TABLE schema_a.table_a (id INT, data TEXT);
|
||||
CREATE TABLE schema_b.table_b (id INT, data TEXT);
|
||||
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
|
||||
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var dataA string
|
||||
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_a", dataA)
|
||||
|
||||
var dataB string
|
||||
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_b", dataB)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, container.Username, container.Password, dashDbName)
|
||||
dashDB, err := sqlx.Connect("postgres", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum("16"),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "databasus-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
t.Skip("Skipping Supabase test: missing environment variables")
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(env.TestSupabasePort)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
env.TestSupabaseUsername,
|
||||
env.TestSupabasePassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
|
||||
adminDB, err := sqlx.Connect("postgres", dsn)
|
||||
require.NoError(t, err)
|
||||
defer adminDB.Close()
|
||||
|
||||
tableName := fmt.Sprintf(
|
||||
"readonly_test_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = adminDB.Exec(fmt.Sprintf(`
|
||||
DROP TABLE IF EXISTS public.%s CASCADE;
|
||||
CREATE TABLE public.%s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
|
||||
`, tableName, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: env.TestSupabaseHost,
|
||||
Port: portInt,
|
||||
Username: env.TestSupabaseUsername,
|
||||
Password: env.TestSupabasePassword,
|
||||
Database: &env.TestSupabaseDatabase,
|
||||
IsHttps: true,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, connectionUsername)
|
||||
assert.NotEmpty(t, newPassword)
|
||||
assert.True(t, strings.HasPrefix(connectionUsername, "databasus-"))
|
||||
|
||||
baseUsername := connectionUsername
|
||||
if idx := strings.Index(connectionUsername, "."); idx != -1 {
|
||||
baseUsername = connectionUsername[:idx]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
|
||||
}()
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
connectionUsername,
|
||||
newPassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
host, portInt, username, password, dbName)
|
||||
|
||||
db, err := sqlx.Connect("postgres", dsn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var versionStr string
|
||||
err = db.Get(&versionStr, "SELECT version()")
|
||||
assert.NoError(t, err)
|
||||
|
||||
return &PostgresContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
|
||||
var versionStr string
|
||||
err := container.DB.Get(&versionStr, "SELECT version()")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
version := extractPostgresVersion(versionStr)
|
||||
|
||||
return &PostgresqlDatabase{
|
||||
Version: version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
|
||||
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
|
||||
if strings.Contains(versionStr, "PostgreSQL 12") {
|
||||
return tools.GetPostgresqlVersionEnum("12")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 13") {
|
||||
return tools.GetPostgresqlVersionEnum("13")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 14") {
|
||||
return tools.GetPostgresqlVersionEnum("14")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 15") {
|
||||
return tools.GetPostgresqlVersionEnum("15")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 16") {
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 17") {
|
||||
return tools.GetPostgresqlVersionEnum("17")
|
||||
}
|
||||
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
}
|
||||
@@ -6,5 +6,6 @@ type CreateReadOnlyUserResponse struct {
|
||||
}
|
||||
|
||||
type IsReadOnlyResponse struct {
|
||||
IsReadOnly bool `json:"isReadOnly"`
|
||||
IsReadOnly bool `json:"isReadOnly"`
|
||||
Privileges []string `json:"privileges"`
|
||||
}
|
||||
|
||||
@@ -104,21 +104,21 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) PopulateVersionIfEmpty(
|
||||
func (d *Database) PopulateDbData(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Postgresql.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mysql != nil {
|
||||
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mysql.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mariadb != nil {
|
||||
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mariadb.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mongodb != nil {
|
||||
return d.Mongodb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
return d.Mongodb.PopulateDbData(logger, encryptor, d.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -82,8 +82,8 @@ func (s *DatabaseService) CreateDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
if err := database.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
@@ -149,8 +149,8 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
if err := existingDatabase.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
@@ -594,17 +594,17 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
|
||||
func (s *DatabaseService) IsUserReadOnly(
|
||||
user *users_models.User,
|
||||
database *Database,
|
||||
) (bool, error) {
|
||||
) (bool, []string, error) {
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return false, errors.New("cannot check user for database without workspace")
|
||||
return false, nil, errors.New("cannot check user for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
@@ -612,20 +612,20 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this database")
|
||||
return false, nil, errors.New("insufficient permissions to access this database")
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return false, errors.New("database does not belong to this workspace")
|
||||
return false, nil, errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
@@ -633,10 +633,10 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
if database.WorkspaceID != nil {
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this workspace")
|
||||
return false, nil, errors.New("insufficient permissions to access this workspace")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -676,7 +676,7 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
usingDatabase.ID,
|
||||
)
|
||||
default:
|
||||
return false, errors.New("read-only check not supported for this database type")
|
||||
return false, nil, errors.New("read-only check not supported for this database type")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
"databasus-backend/internal/features/databases/databases/mariadb"
|
||||
"databasus-backend/internal/features/databases/databases/mongodb"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
@@ -9,6 +15,71 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
|
||||
env := config.GetEnv()
|
||||
port, err := strconv.Atoi(env.TestPostgres16Port)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func GetTestMariadbConfig() *mariadb.MariadbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMariadb1011Port
|
||||
if portStr == "" {
|
||||
portStr = "33111"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
return &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
}
|
||||
}
|
||||
|
||||
func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMongodb70Port
|
||||
if portStr == "" {
|
||||
portStr = "27070"
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
|
||||
}
|
||||
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: "localhost",
|
||||
Port: port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
func CreateTestDatabase(
|
||||
workspaceID uuid.UUID,
|
||||
storage *storages.Storage,
|
||||
@@ -18,16 +89,7 @@ func CreateTestDatabase(
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
CpuCount: 1,
|
||||
},
|
||||
|
||||
Postgresql: GetTestPostgresConfig(),
|
||||
Notifiers: []notifiers.Notifier{
|
||||
*notifier,
|
||||
},
|
||||
|
||||
@@ -12,13 +12,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
@@ -205,20 +203,11 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: databases.GetTestPostgresConfig(),
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
|
||||
@@ -10,13 +10,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
@@ -293,20 +291,11 @@ func createTestDatabaseViaAPI(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: databases.GetTestPostgresConfig(),
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
audit_logs "databasus-backend/internal/features/audit_logs"
|
||||
"databasus-backend/internal/features/backups/backups"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
@@ -390,20 +392,11 @@ func createTestDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
CpuCount: 1,
|
||||
},
|
||||
Postgresql: databases.GetTestPostgresConfig(),
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
@@ -434,7 +427,18 @@ func createTestMySQLDatabase(
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
env := config.GetEnv()
|
||||
portStr := env.TestMysql80Port
|
||||
if portStr == "" {
|
||||
portStr = "33080"
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Failed to parse TEST_MYSQL_80_PORT: %v", err))
|
||||
}
|
||||
|
||||
testDbName := "testdb"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
@@ -442,9 +446,9 @@ func createTestMySQLDatabase(
|
||||
Mysql: &mysql.MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "password",
|
||||
Port: port,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -229,8 +229,8 @@ func (s *RestoreService) RestoreBackup(
|
||||
Mongodb: requestDTO.MongodbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
if err := restoringToDB.PopulateDbData(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database data: %w", err)
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
|
||||
@@ -149,101 +149,6 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMariadb_WithExcludeEvents_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container, err := connectToMariadbContainer(tools.MariadbVersion120, env.TestMariadb120Port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MariaDB 12.0 IsExcludeEvents test: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMariadbTestData(t, container.DB)
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"MariaDB ExcludeEvents Test Workspace",
|
||||
user,
|
||||
router,
|
||||
)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMariadbDatabaseWithExcludeEventsViaAPI(
|
||||
t, router, "MariaDB ExcludeEvents Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
true,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, database.ID, storage.ID,
|
||||
backups_config.BackupEncryptionNone, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb_excludeevents"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMariadbRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMariadbDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+database.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func testMariadbBackupRestoreForVersion(
|
||||
t *testing.T,
|
||||
mariadbVersion tools.MariadbVersion,
|
||||
@@ -554,40 +459,18 @@ func createMariadbDatabaseViaAPI(
|
||||
database string,
|
||||
version tools.MariadbVersion,
|
||||
token string,
|
||||
) *databases.Database {
|
||||
return createMariadbDatabaseWithExcludeEventsViaAPI(
|
||||
t, router, name, workspaceID,
|
||||
host, port, username, password, database,
|
||||
version, false, token,
|
||||
)
|
||||
}
|
||||
|
||||
func createMariadbDatabaseWithExcludeEventsViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
host string,
|
||||
port int,
|
||||
username string,
|
||||
password string,
|
||||
database string,
|
||||
version tools.MariadbVersion,
|
||||
isExcludeEvents bool,
|
||||
token string,
|
||||
) *databases.Database {
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: databases.DatabaseTypeMariadb,
|
||||
Mariadb: &mariadbtypes.MariadbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: &database,
|
||||
Version: version,
|
||||
IsExcludeEvents: isExcludeEvents,
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: &database,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mysql_databases
|
||||
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
|
||||
|
||||
ALTER TABLE mariadb_databases
|
||||
ADD COLUMN privileges TEXT NOT NULL DEFAULT '';
|
||||
|
||||
ALTER TABLE mariadb_databases
|
||||
DROP COLUMN is_exclude_events;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mariadb_databases
|
||||
ADD COLUMN is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
ALTER TABLE mariadb_databases
|
||||
DROP COLUMN privileges;
|
||||
|
||||
ALTER TABLE mysql_databases
|
||||
DROP COLUMN privileges;
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,43 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE notifiers
|
||||
DROP CONSTRAINT fk_notifiers_workspace_id;
|
||||
|
||||
ALTER TABLE notifiers
|
||||
ADD CONSTRAINT fk_notifiers_workspace_id
|
||||
FOREIGN KEY (workspace_id)
|
||||
REFERENCES workspaces (id)
|
||||
ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE storages
|
||||
DROP CONSTRAINT fk_storages_workspace_id;
|
||||
|
||||
ALTER TABLE storages
|
||||
ADD CONSTRAINT fk_storages_workspace_id
|
||||
FOREIGN KEY (workspace_id)
|
||||
REFERENCES workspaces (id)
|
||||
ON DELETE CASCADE;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE notifiers
|
||||
DROP CONSTRAINT fk_notifiers_workspace_id;
|
||||
|
||||
ALTER TABLE notifiers
|
||||
ADD CONSTRAINT fk_notifiers_workspace_id
|
||||
FOREIGN KEY (workspace_id)
|
||||
REFERENCES workspaces (id);
|
||||
|
||||
ALTER TABLE storages
|
||||
DROP CONSTRAINT fk_storages_workspace_id;
|
||||
|
||||
ALTER TABLE storages
|
||||
ADD CONSTRAINT fk_storages_workspace_id
|
||||
FOREIGN KEY (workspace_id)
|
||||
REFERENCES workspaces (id);
|
||||
|
||||
-- +goose StatementEnd
|
||||
@@ -1,3 +1,4 @@
|
||||
export interface IsReadOnlyResponse {
|
||||
isReadOnly: boolean;
|
||||
privileges: string[];
|
||||
}
|
||||
|
||||
@@ -163,6 +163,10 @@ export const CreateDatabaseComponent = ({ workspaceId, onCreated, onClose }: Pro
|
||||
}
|
||||
|
||||
if (step === 'notifiers') {
|
||||
if (isCreating) {
|
||||
return <div>Creating database...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<EditDatabaseNotifiersComponent
|
||||
database={database}
|
||||
|
||||
@@ -11,6 +11,8 @@ interface Props {
|
||||
onContinue: () => void;
|
||||
}
|
||||
|
||||
const PRIVILEGES_TRUNCATE_LENGTH = 50;
|
||||
|
||||
export const CreateReadOnlyComponent = ({
|
||||
database,
|
||||
onReadOnlyUserUpdated,
|
||||
@@ -20,6 +22,8 @@ export const CreateReadOnlyComponent = ({
|
||||
const [isCheckingReadOnlyUser, setIsCheckingReadOnlyUser] = useState(false);
|
||||
const [isCreatingReadOnlyUser, setIsCreatingReadOnlyUser] = useState(false);
|
||||
const [isShowSkipConfirmation, setShowSkipConfirmation] = useState(false);
|
||||
const [privileges, setPrivileges] = useState<string[]>([]);
|
||||
const [isPrivilegesExpanded, setIsPrivilegesExpanded] = useState(false);
|
||||
|
||||
const isPostgres = database.type === DatabaseType.POSTGRES;
|
||||
const isMysql = database.type === DatabaseType.MYSQL;
|
||||
@@ -35,9 +39,12 @@ export const CreateReadOnlyComponent = ({
|
||||
? 'MongoDB'
|
||||
: 'database';
|
||||
|
||||
const privilegesLabel = isMongodb ? 'roles' : 'privileges';
|
||||
|
||||
const checkReadOnlyUser = async (): Promise<boolean> => {
|
||||
try {
|
||||
const response = await databaseApi.isUserReadOnly(database);
|
||||
setPrivileges(response.privileges || []);
|
||||
return response.isReadOnly;
|
||||
} catch (e) {
|
||||
alert((e as Error).message);
|
||||
@@ -45,6 +52,20 @@ export const CreateReadOnlyComponent = ({
|
||||
}
|
||||
};
|
||||
|
||||
const getPrivilegesDisplay = () => {
|
||||
const fullText = privileges.join(', ');
|
||||
if (isPrivilegesExpanded || fullText.length <= PRIVILEGES_TRUNCATE_LENGTH) {
|
||||
return fullText;
|
||||
}
|
||||
|
||||
return fullText.substring(0, PRIVILEGES_TRUNCATE_LENGTH) + '...';
|
||||
};
|
||||
|
||||
const shouldShowExpandToggle = () => {
|
||||
const fullText = privileges.join(', ');
|
||||
return fullText.length > PRIVILEGES_TRUNCATE_LENGTH;
|
||||
};
|
||||
|
||||
const createReadOnlyUser = async () => {
|
||||
setIsCreatingReadOnlyUser(true);
|
||||
|
||||
@@ -139,6 +160,31 @@ export const CreateReadOnlyComponent = ({
|
||||
<b>Read-only user allows to avoid storing credentials with write access at all</b>. Even
|
||||
in the worst case of hacking, nobody will be able to corrupt your data.
|
||||
</p>
|
||||
|
||||
<p className="mt-3">
|
||||
{privileges.length === 0 ? (
|
||||
<>
|
||||
Current user has <b>no write {privilegesLabel}</b>.
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Current user has the following write {privilegesLabel}:{' '}
|
||||
<span
|
||||
className={shouldShowExpandToggle() ? 'cursor-pointer hover:opacity-80' : ''}
|
||||
onClick={() =>
|
||||
shouldShowExpandToggle() && setIsPrivilegesExpanded(!isPrivilegesExpanded)
|
||||
}
|
||||
>
|
||||
{getPrivilegesDisplay()}
|
||||
{shouldShowExpandToggle() && (
|
||||
<span className="ml-1 text-xs text-blue-600 hover:opacity-80">
|
||||
({isPrivilegesExpanded ? 'collapse' : 'expand'})
|
||||
</span>
|
||||
)}
|
||||
</span>
|
||||
</>
|
||||
)}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="mt-5 flex">
|
||||
|
||||
Reference in New Issue
Block a user