Compare commits

...

12 Commits

Author SHA1 Message Date
Rostislav Dugin
2553203fcf Merge pull request #363 from databasus/develop
FIX (sign up): Return authorization token on sign up to avoid 2-step …
2026-02-15 00:09:00 +03:00
Rostislav Dugin
7b05bd8000 FIX (sign up): Return authorization token on sign up to avoid 2-step sign up 2026-02-15 00:08:01 +03:00
Rostislav Dugin
8d45728f73 Merge pull request #362 from databasus/develop
FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign …
2026-02-14 23:19:12 +03:00
Rostislav Dugin
c70ad82c95 FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign up \ password reset 2026-02-14 23:11:36 +03:00
Rostislav Dugin
e4bc34d319 Merge pull request #361 from databasus/develop
Develop
2026-02-13 16:57:25 +03:00
Rostislav Dugin
257ae85da7 FIX (postgres): Fix read-only issue when user cannot access tables and partitions created after user creation 2026-02-13 16:56:56 +03:00
Rostislav Dugin
b42c820bb2 FIX (mariadb): Fix events exclusion 2026-02-13 16:21:48 +03:00
Rostislav Dugin
da5c13fb11 Merge pull request #356 from databasus/develop
FIX (mysql & mariadb): Fix creation of backups with exremely large SQ…
2026-02-10 22:40:06 +03:00
Rostislav Dugin
35180360e5 FIX (mysql & mariadb): Fix creation of backups with exremely large SQL statements to avoid OOM 2026-02-10 22:38:18 +03:00
Rostislav Dugin
e4f6cd7a5d Merge pull request #349 from databasus/develop
Develop
2026-02-09 16:42:00 +03:00
Rostislav Dugin
d7b8e6d56a Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-09 16:40:46 +03:00
Rostislav Dugin
6016f23fb2 FEATURE (svr): Add SVR support 2026-02-09 16:39:51 +03:00
43 changed files with 1717 additions and 354 deletions

View File

@@ -268,7 +268,8 @@ window.__RUNTIME_CONFIG__ = {
IS_CLOUD: '\${IS_CLOUD:-false}',
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED'
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
};
JSEOF

View File

@@ -11,7 +11,7 @@
[![MongoDB](https://img.shields.io/badge/MongoDB-47A248?logo=mongodb&logoColor=white)](https://www.mongodb.com/)
<br />
[![Apache 2.0 License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Docker Pulls](https://img.shields.io/docker/pulls/databasus/databasus?color=brightgreen)](https://hub.docker.com/r/databasus/databasus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/databasus/databasus)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/databasus/databasus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/databasus/databasus)
@@ -31,8 +31,6 @@
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
</div>
---

View File

@@ -11,6 +11,9 @@ VICTORIA_LOGS_PASSWORD=devpassword
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# cloudflare turnstile
CLOUDFLARE_TURNSTILE_SITE_KEY=
CLOUDFLARE_TURNSTILE_SECRET_KEY=
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

View File

@@ -104,6 +104,10 @@ type EnvVariables struct {
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// Cloudflare Turnstile
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`

View File

@@ -1366,11 +1366,24 @@ func createTestBackup(
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(loadedStorages) == 0 {
panic("No storage found for workspace")
}
// Filter out system storages
var nonSystemStorages []*storages.Storage
for _, storage := range loadedStorages {
if !storage.IsSystem {
nonSystemStorages = append(nonSystemStorages, storage)
}
}
if len(nonSystemStorages) == 0 {
panic("No non-system storage found for workspace")
}
storages := nonSystemStorages
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,

View File

@@ -108,13 +108,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--single-transaction",
"--routines",
"--quick",
"--skip-extended-insert",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
args = append(args, "--events")
}

View File

@@ -107,6 +107,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--routines",
"--set-gtid-purged=OFF",
"--quick",
"--skip-extended-insert",
"--verbose",
}

View File

@@ -1164,12 +1164,13 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: port,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}

View File

@@ -25,13 +25,14 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.IsExcludeEvents = incoming.IsExcludeEvents
m.Privileges = incoming.Privileges
if incoming.Password != "" {

View File

@@ -26,12 +26,13 @@ type MongodbDatabase struct {
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
Host string `json:"host" gorm:"type:text;not null"`
Port int `json:"port" gorm:"type:int;not null"`
Port *int `json:"port" gorm:"type:int"`
Username string `json:"username" gorm:"type:text;not null"`
Password string `json:"password" gorm:"type:text;not null"`
Database string `json:"database" gorm:"type:text;not null"`
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
}
@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
if m.Host == "" {
return errors.New("host is required")
}
if m.Port == 0 {
return errors.New("port is required")
if !m.IsSrv {
if m.Port == nil || *m.Port == 0 {
return errors.New("port is required for standard connections")
}
}
if m.Username == "" {
return errors.New("username is required")
}
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
if m.CpuCount <= 0 {
return errors.New("cpu count must be greater than 0")
}
return nil
}
@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
m.Database = incoming.Database
m.AuthDatabase = incoming.AuthDatabase
m.IsHttps = incoming.IsHttps
m.IsSrv = incoming.IsSrv
m.CpuCount = incoming.CpuCount
if incoming.Password != "" {
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
tlsParams = "&tls=true&tlsInsecure=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Database,
authDB,
tlsParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
port,
m.Database,
authDB,
tlsParams,
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
tlsParams = "&tls=true&tlsInsecure=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
authDB,
tlsParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Port,
port,
authDB,
tlsParams,
)

View File

@@ -64,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: limitedUsername,
Password: limitedPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -133,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: backupUsername,
Password: backupPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -442,15 +446,17 @@ func connectToMongodbContainer(
}
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
port := container.Port
return &MongodbDatabase{
Version: container.Version,
Host: container.Host,
Port: container.Port,
Port: &port,
Username: container.Username,
Password: container.Password,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}
@@ -489,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
strings.Contains(errStr, "permission denied"),
"Expected authorization error, got: %v", err)
}
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
}
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
}
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "localhost:27017")
}
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
assert.NotContains(t, uri, "/mydb")
}
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
assert.NotContains(t, uri, "/mydb")
}
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
CpuCount: 1,
}
err := model.Validate()
assert.NoError(t, err)
}
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
err := model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "port is required for standard connections")
}

View File

@@ -564,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 4: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
// Step 4: Discover schemas to grant privileges on
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
var rows pgx.Rows
if len(p.IncludeSchemas) > 0 {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
AND schema_name = ANY($1::text[])
`, p.IncludeSchemas)
} else {
rows, err = tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
}
if err != nil {
return "", "", fmt.Errorf("failed to get schemas: %w", err)
}
@@ -619,50 +630,197 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
// Step 6: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, grantSelectSQL)
if err != nil {
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to grant select on sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 7: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
// First, set default privileges for objects created by the current user
// Use the already-filtered schemas list from Step 4
for _, schema := range schemas {
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for tables in schema %s: %w",
schema,
err,
)
}
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
if err != nil {
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
schema,
baseUsername,
),
)
if err != nil {
return "", "", fmt.Errorf(
"failed to set default privileges for sequences in schema %s: %w",
schema,
err,
)
}
}
// Step 8: Verify user creation before committing
// Step 8: Discover all roles that own objects in each schema
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
// Filter by IncludeSchemas if specified.
type SchemaOwner struct {
SchemaName string
RoleName string
}
var ownerRows pgx.Rows
if len(p.IncludeSchemas) > 0 {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname = ANY($1::text[])
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`, p.IncludeSchemas)
} else {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`)
}
if err != nil {
// Log warning but continue - this is a best-effort enhancement
logger.Warn("Failed to query object owners for default privileges", "error", err)
} else {
var schemaOwners []SchemaOwner
for ownerRows.Next() {
var so SchemaOwner
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
ownerRows.Close()
logger.Warn("Failed to scan schema owner", "error", err)
break
}
schemaOwners = append(schemaOwners, so)
}
ownerRows.Close()
if err := ownerRows.Err(); err != nil {
logger.Warn("Error iterating schema owners", "error", err)
}
// Step 9: Set default privileges FOR ROLE for each object owner
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
// We log warnings but continue - user creation should succeed even if some roles can't be configured
for _, so := range schemaOwners {
// Try to set default privileges for tables
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (tables)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
// Try to set default privileges for sequences
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (sequences)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
}
if len(schemaOwners) > 0 {
logger.Info(
"Set default privileges for existing object owners",
"readonly_user",
baseUsername,
"owner_count",
len(schemaOwners),
)
}
}
// Step 10: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)

View File

@@ -1319,6 +1319,346 @@ type PostgresContainer struct {
DB *sqlx.DB
}
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create a second database user who will create tables
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err := container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Step 2: Grant the user_creator privileges to connect and create tables
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CREATE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
// Step 2b: Create an initial table by user_creator so they become an object owner
// This is important because our fix discovers existing object owners
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
initialTableName := fmt.Sprintf(
"public.initial_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('initial_data');
`, initialTableName, initialTableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
}()
// Step 3: NOW create read-only user via Databasus (as admin)
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
tableName := fmt.Sprintf(
"public.future_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
`, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
}()
// Step 5: Connect as read-only user and verify it can SELECT from the new table
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
var count int
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
assert.NoError(t, err)
assert.Equal(
t,
2,
count,
"Read-only user should be able to SELECT from table created by different user",
)
// Step 6: Verify read-only user cannot write to the table
_, err = readonlyConn.Exec(
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify pg_dump operations (LOCK TABLE) work
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
tx, err := readonlyConn.Begin()
assert.NoError(t, err)
defer tx.Rollback()
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
err = tx.Commit()
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create multiple schemas and tables
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS included_schema CASCADE;
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
CREATE SCHEMA included_schema;
CREATE SCHEMA excluded_schema;
CREATE TABLE public.public_table (id INT, data TEXT);
INSERT INTO public.public_table VALUES (1, 'public_data');
CREATE TABLE included_schema.included_table (id INT, data TEXT);
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
`)
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
}()
// Step 2: Create a second user who owns tables in both included and excluded schemas
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Grant privileges to user_creator
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
schema,
userCreatorUsername,
))
assert.NoError(t, err)
}
// User_creator creates tables in included and excluded schemas
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.user_table (id INT, data TEXT);
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
`)
assert.NoError(t, err)
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
pgModel := createPostgresModel(container)
pgModel.IncludeSchemas = []string{"public", "included_schema"}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: Connect as read-only user
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
// Step 5: Verify read-only user CAN access included schemas
var publicData string
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "public_data", publicData)
var includedData string
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "included_data", includedData)
var userIncludedData string
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "user_included_data", userIncludedData)
// Step 6: Verify read-only user CANNOT access excluded schema
var excludedData string
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
err = readonlyConn.Get(
&excludedData,
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify future tables in included schemas are accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.future_table (id INT, data TEXT);
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
`)
assert.NoError(t, err)
var futureData string
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(
t,
"future_data",
futureData,
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
)
// Step 8: Verify future tables in excluded schema are NOT accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
`)
assert.NoError(t, err)
var futureExcludedData string
err = readonlyConn.Get(
&futureExcludedData,
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(
t,
err.Error(),
"permission denied",
"Read-only user should NOT access tables in excluded schemas",
)
}
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"

View File

@@ -71,12 +71,13 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: port,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}

View File

@@ -32,7 +32,6 @@ import (
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
@@ -358,7 +357,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
backup = createTestBackup(mysqlDB, storage)
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
@@ -610,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
panic(err)
}
backup := createTestBackup(database, owner)
backup := createTestBackup(database, storage)
return database, backup
}
@@ -727,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
storage *storages.Storage,
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
@@ -759,7 +748,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(
if err := storage.SaveFile(
context.Background(),
fieldEncryptor,
logger,

View File

@@ -147,6 +147,26 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
}
}
func Test_BackupAndRestoreMariadb_WithExcludeEvents_EventsNotRestored(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testMariadbBackupRestoreWithExcludeEventsForVersion(t, tc.version, tc.port)
})
}
}
func testMariadbBackupRestoreForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
@@ -702,3 +722,145 @@ func updateMariadbDatabaseCredentialsViaAPI(
return &updatedDatabase
}
func testMariadbBackupRestoreWithExcludeEventsForVersion(
t *testing.T,
mariadbVersion tools.MariadbVersion,
port string,
) {
container, err := connectToMariadbContainer(mariadbVersion, port)
if err != nil {
t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
return
}
defer func() {
if container.DB != nil {
container.DB.Close()
}
}()
setupMariadbTestData(t, container.DB)
_, err = container.DB.Exec(`
CREATE EVENT IF NOT EXISTS test_event
ON SCHEDULE EVERY 1 DAY
DO BEGIN
INSERT INTO test_data (name, value) VALUES ('event_test', 999);
END
`)
if err != nil {
t.Skipf(
"Skipping test: MariaDB version doesn't support events or event scheduler disabled: %v",
err,
)
return
}
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace(
"MariaDB Exclude Events Test Workspace",
user,
router,
)
storage := storages.CreateTestStorage(workspace.ID)
database := createMariadbDatabaseViaAPI(
t, router, "MariaDB Exclude Events Test Database", workspace.ID,
container.Host, container.Port,
container.Username, container.Password, container.Database,
container.Version,
user.Token,
)
database.Mariadb.IsExcludeEvents = true
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/update",
"Bearer "+user.Token,
database,
)
if w.Code != http.StatusOK {
t.Fatalf(
"Failed to update database with IsExcludeEvents. Status: %d, Body: %s",
w.Code,
w.Body.String(),
)
}
enableBackupsViaAPI(
t, router, database.ID, storage.ID,
backups_config.BackupEncryptionNone, user.Token,
)
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb_mariadb_no_events"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username, container.Password, container.Host, container.Port, newDBName)
newDB, err := sqlx.Connect("mysql", newDSN)
assert.NoError(t, err)
defer newDB.Close()
createMariadbRestoreViaAPI(
t, router, backup.ID,
container.Host, container.Port,
container.Username, container.Password, newDBName,
container.Version,
user.Token,
)
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
var tableExists int
err = newDB.Get(
&tableExists,
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
verifyMariadbDataIntegrity(t, container.DB, newDB)
var eventCount int
err = newDB.Get(
&eventCount,
"SELECT COUNT(*) FROM information_schema.events WHERE event_schema = ? AND event_name = 'test_event'",
newDBName,
)
assert.NoError(t, err)
assert.Equal(
t,
0,
eventCount,
"Event 'test_event' should NOT exist in restored database when IsExcludeEvents is true",
)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
t.Logf("Warning: Failed to delete backup file: %v", err)
}
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+user.Token,
http.StatusNoContent,
)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -385,13 +385,14 @@ func createMongodbDatabaseViaAPI(
Type: databases.DatabaseTypeMongodb,
Mongodb: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
Port: &port,
Username: username,
Password: password,
Database: database,
AuthDatabase: authDatabase,
Version: version,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
},
}
@@ -432,13 +433,14 @@ func createMongodbRestoreViaAPI(
request := restores_core.RestoreBackupRequest{
MongodbDatabase: &mongodbtypes.MongodbDatabase{
Host: host,
Port: port,
Port: &port,
Username: username,
Password: password,
Database: database,
AuthDatabase: authDatabase,
Version: version,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
},
}

View File

@@ -726,41 +726,28 @@ func Test_InviteUserToWorkspace_MembershipReceivedAfterSignUp(t *testing.T) {
assert.Equal(t, workspaces_dto.AddStatusInvited, inviteResponse.Status)
// 3. Sign up the invited user
// 3. Sign up the invited user (now returns token directly)
signUpRequest := users_dto.SignUpRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
Name: "Invited User",
}
resp := test_utils.MakePostRequest(
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
signUpRequest,
http.StatusOK,
)
assert.Contains(t, string(resp.Body), "User created successfully")
// 4. Sign in the newly registered user
signInRequest := users_dto.SignInRequestDTO{
Email: inviteEmail,
Password: "testpassword123",
}
var signInResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
signInRequest,
http.StatusOK,
&signInResponse,
)
// 5. Verify user is automatically added as member to workspace
assert.NotEmpty(t, signInResponse.Token)
assert.Equal(t, inviteEmail, signInResponse.Email)
// 4. Verify user is automatically added as member to workspace
var membersResponse workspaces_dto.GetMembersResponseDTO
test_utils.MakeGetRequestAndUnmarshal(
t,

View File

@@ -11,6 +11,7 @@ import (
user_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
cache_utils "databasus-backend/internal/util/cache"
cloudflare_turnstile "databasus-backend/internal/util/cloudflare_turnstile"
"github.com/gin-gonic/gin"
)
@@ -51,7 +52,7 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param request body users_dto.SignUpRequestDTO true "User signup data"
// @Success 200
// @Success 200 {object} users_dto.SignInResponseDTO
// @Failure 400
// @Router /users/signup [post]
func (c *UserController) SignUp(ctx *gin.Context) {
@@ -61,13 +62,41 @@ func (c *UserController) SignUp(ctx *gin.Context) {
return
}
err := c.userService.SignUp(&request)
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
user, err := c.userService.SignUp(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
response, err := c.userService.GenerateAccessToken(user)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
return
}
ctx.JSON(http.StatusOK, response)
}
// SignIn
@@ -88,6 +117,28 @@ func (c *UserController) SignIn(ctx *gin.Context) {
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
if !allowed {
ctx.JSON(
@@ -363,6 +414,28 @@ func (c *UserController) SendResetPasswordCode(ctx *gin.Context) {
return
}
// Verify Cloudflare Turnstile if enabled
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
if turnstileService.IsEnabled() {
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification required"},
)
return
}
clientIP := ctx.ClientIP()
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
if err != nil || !isValid {
ctx.JSON(
http.StatusBadRequest,
gin.H{"error": "Cloudflare Turnstile verification failed"},
)
return
}
}
allowed, _ := c.rateLimiter.CheckLimit(
request.Email,
"reset-password",

View File

@@ -27,7 +27,20 @@ func Test_SignUpUser_WithValidData_UserCreated(t *testing.T) {
Name: "Test User",
}
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", request, http.StatusOK)
var response users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signup",
"",
request,
http.StatusOK,
&response,
)
assert.NotEmpty(t, response.Token)
assert.NotEqual(t, uuid.Nil, response.UserID)
assert.Equal(t, request.Email, response.Email)
}
func Test_SignUpUser_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {

View File

@@ -9,14 +9,16 @@ import (
)
type SignUpRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required,min=8"`
Name string `json:"name" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInRequestDTO struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type SignInResponseDTO struct {
@@ -94,7 +96,8 @@ type OAuthCallbackResponseDTO struct {
}
type SendResetPasswordCodeRequestDTO struct {
Email string `json:"email" binding:"required,email"`
Email string `json:"email" binding:"required,email"`
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
}
type ResetPasswordRequestDTO struct {

View File

@@ -44,19 +44,19 @@ func (s *UserService) SetEmailSender(sender users_interfaces.EmailSender) {
s.emailSender = sender
}
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) (*users_models.User, error) {
existingUser, err := s.userRepository.GetUserByEmail(request.Email)
if err != nil {
return fmt.Errorf("failed to check existing user: %w", err)
return nil, fmt.Errorf("failed to check existing user: %w", err)
}
if existingUser != nil && existingUser.Status != users_enums.UserStatusInvited {
return errors.New("user with this email already exists")
return nil, errors.New("user with this email already exists")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash password: %w", err)
return nil, fmt.Errorf("failed to hash password: %w", err)
}
hashedPasswordStr := string(hashedPassword)
@@ -67,39 +67,45 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
existingUser.ID,
hashedPasswordStr,
); err != nil {
return fmt.Errorf("failed to set password: %w", err)
return nil, fmt.Errorf("failed to set password: %w", err)
}
if err := s.userRepository.UpdateUserStatus(
existingUser.ID,
users_enums.UserStatusActive,
); err != nil {
return fmt.Errorf("failed to activate user: %w", err)
return nil, fmt.Errorf("failed to activate user: %w", err)
}
name := request.Name
if err := s.userRepository.UpdateUserInfo(existingUser.ID, &name, nil); err != nil {
return fmt.Errorf("failed to update name: %w", err)
return nil, fmt.Errorf("failed to update name: %w", err)
}
// Fetch updated user to ensure we have the latest data
updatedUser, err := s.userRepository.GetUserByID(existingUser.ID)
if err != nil {
return nil, fmt.Errorf("failed to get updated user: %w", err)
}
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Invited user completed registration: %s", existingUser.Email),
&existingUser.ID,
fmt.Sprintf("Invited user completed registration: %s", updatedUser.Email),
&updatedUser.ID,
nil,
)
return nil
return updatedUser, nil
}
// Get settings to check registration policy for new users
settings, err := s.settingsService.GetSettings()
if err != nil {
return fmt.Errorf("failed to get settings: %w", err)
return nil, fmt.Errorf("failed to get settings: %w", err)
}
// Check if external registrations are allowed
if !settings.IsAllowExternalRegistrations {
return errors.New("external registration is disabled")
return nil, errors.New("external registration is disabled")
}
user := &users_models.User{
@@ -114,7 +120,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
}
if err := s.userRepository.CreateUser(user); err != nil {
return fmt.Errorf("failed to create user: %w", err)
return nil, fmt.Errorf("failed to create user: %w", err)
}
s.auditLogWriter.WriteAuditLog(
@@ -123,7 +129,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
nil,
)
return nil
return user, nil
}
func (s *UserService) SignIn(
@@ -258,6 +264,7 @@ func (s *UserService) GenerateAccessToken(
return &users_dto.SignInResponseDTO{
UserID: user.ID,
Email: user.Email,
Token: tokenString,
}, nil
}
@@ -463,6 +470,178 @@ func (s *UserService) HandleGitHubOAuth(
)
}
func (s *UserService) HandleGoogleOAuth(
code, redirectUri string,
) (*users_dto.OAuthCallbackResponseDTO, error) {
return s.handleGoogleOAuthWithEndpoint(
code,
redirectUri,
google.Endpoint,
"https://www.googleapis.com/oauth2/v2/userinfo",
)
}
func (s *UserService) SendResetPasswordCode(email string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Silently succeed for non-existent users to prevent enumeration attacks
if user == nil {
return nil
}
// Only active users can reset passwords
if user.Status != users_enums.UserStatusActive {
return errors.New("only active users can reset their password")
}
// Check rate limiting - max 3 codes per hour
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
if err != nil {
return fmt.Errorf("failed to check rate limit: %w", err)
}
if recentCount >= 3 {
return errors.New("too many password reset attempts, please try again later")
}
// Generate 6-digit random code using crypto/rand for better randomness
codeNum := make([]byte, 4)
_, err = io.ReadFull(rand.Reader, codeNum)
if err != nil {
return fmt.Errorf("failed to generate random code: %w", err)
}
// Convert bytes to uint32 and modulo to get 6 digits
randomInt := uint32(
codeNum[0],
)<<24 | uint32(
codeNum[1],
)<<16 | uint32(
codeNum[2],
)<<8 | uint32(
codeNum[3],
)
code := fmt.Sprintf("%06d", randomInt%1000000)
// Hash the code
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash code: %w", err)
}
// Store in database with 1 hour expiration
resetCode := &users_models.PasswordResetCode{
ID: uuid.New(),
UserID: user.ID,
HashedCode: string(hashedCode),
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
IsUsed: false,
CreatedAt: time.Now().UTC(),
}
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
return fmt.Errorf("failed to create reset code: %w", err)
}
// Send email with code
if s.emailSender != nil {
subject := "Password Reset Code"
body := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Please use the following code to complete the password reset process:
</p>
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
</div>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
This code will expire in <strong>1 hour</strong>.
</p>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
</p>
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
This is an automated message. Please do not reply to this email.
</p>
</div>
</body>
</html>
`, code)
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Password reset code sent to: %s", user.Email),
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) ResetPassword(email, code, newPassword string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return errors.New("user with this email does not exist")
}
// Get valid reset code for user
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
if err != nil {
return errors.New("invalid or expired reset code")
}
// Verify code matches
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
if err != nil {
return errors.New("invalid reset code")
}
// Mark code as used
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
}
// Update user password
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
"Password reset via email code",
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) handleGitHubOAuthWithEndpoint(
code, redirectUri string,
endpoint oauth2.Endpoint,
@@ -529,17 +708,6 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
return s.getOrCreateUserFromOAuth(oauthID, email, name, "github")
}
func (s *UserService) HandleGoogleOAuth(
code, redirectUri string,
) (*users_dto.OAuthCallbackResponseDTO, error) {
return s.handleGoogleOAuthWithEndpoint(
code,
redirectUri,
google.Endpoint,
"https://www.googleapis.com/oauth2/v2/userinfo",
)
}
func (s *UserService) handleGoogleOAuthWithEndpoint(
code, redirectUri string,
endpoint oauth2.Endpoint,
@@ -805,164 +973,3 @@ func (s *UserService) fetchGitHubPrimaryEmail(
return "", errors.New("github account has no accessible email")
}
func (s *UserService) SendResetPasswordCode(email string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
// Silently succeed for non-existent users to prevent enumeration attacks
if user == nil {
return nil
}
// Only active users can reset passwords
if user.Status != users_enums.UserStatusActive {
return errors.New("only active users can reset their password")
}
// Check rate limiting - max 3 codes per hour
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
if err != nil {
return fmt.Errorf("failed to check rate limit: %w", err)
}
if recentCount >= 3 {
return errors.New("too many password reset attempts, please try again later")
}
// Generate 6-digit random code using crypto/rand for better randomness
codeNum := make([]byte, 4)
_, err = io.ReadFull(rand.Reader, codeNum)
if err != nil {
return fmt.Errorf("failed to generate random code: %w", err)
}
// Convert bytes to uint32 and modulo to get 6 digits
randomInt := uint32(
codeNum[0],
)<<24 | uint32(
codeNum[1],
)<<16 | uint32(
codeNum[2],
)<<8 | uint32(
codeNum[3],
)
code := fmt.Sprintf("%06d", randomInt%1000000)
// Hash the code
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash code: %w", err)
}
// Store in database with 1 hour expiration
resetCode := &users_models.PasswordResetCode{
ID: uuid.New(),
UserID: user.ID,
HashedCode: string(hashedCode),
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
IsUsed: false,
CreatedAt: time.Now().UTC(),
}
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
return fmt.Errorf("failed to create reset code: %w", err)
}
// Send email with code
if s.emailSender != nil {
subject := "Password Reset Code"
body := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Please use the following code to complete the password reset process:
</p>
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
</div>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
This code will expire in <strong>1 hour</strong>.
</p>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
</p>
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
This is an automated message. Please do not reply to this email.
</p>
</div>
</body>
</html>
`, code)
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
return fmt.Errorf("failed to send email: %w", err)
}
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
fmt.Sprintf("Password reset code sent to: %s", user.Email),
&user.ID,
nil,
)
}
return nil
}
func (s *UserService) ResetPassword(email, code, newPassword string) error {
user, err := s.userRepository.GetUserByEmail(email)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if user == nil {
return errors.New("user with this email does not exist")
}
// Get valid reset code for user
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
if err != nil {
return errors.New("invalid or expired reset code")
}
// Verify code matches
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
if err != nil {
return errors.New("invalid reset code")
}
// Mark code as used
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
return fmt.Errorf("failed to mark code as used: %w", err)
}
// Update user password
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// Audit log
if s.auditLogWriter != nil {
s.auditLogWriter.WriteAuditLog(
"Password reset via email code",
&user.ID,
nil,
)
}
return nil
}

View File

@@ -0,0 +1,71 @@
package cloudflare_turnstile
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"time"
)
type CloudflareTurnstileService struct {
secretKey string
siteKey string
}
type cloudflareTurnstileResponse struct {
Success bool `json:"success"`
ChallengeTS time.Time `json:"challenge_ts"`
Hostname string `json:"hostname"`
ErrorCodes []string `json:"error-codes"`
}
const cloudflareTurnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
func (s *CloudflareTurnstileService) IsEnabled() bool {
return s.secretKey != ""
}
func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool, error) {
if !s.IsEnabled() {
return true, nil
}
if token == "" {
return false, errors.New("cloudflare Turnstile token is required")
}
formData := url.Values{}
formData.Set("secret", s.secretKey)
formData.Set("response", token)
formData.Set("remoteip", remoteIP)
resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData)
if err != nil {
return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
body, err := io.ReadAll(resp.Body)
if err != nil {
return false, fmt.Errorf("failed to read Cloudflare Turnstile response: %w", err)
}
var turnstileResp cloudflareTurnstileResponse
if err := json.Unmarshal(body, &turnstileResp); err != nil {
return false, fmt.Errorf("failed to parse Cloudflare Turnstile response: %w", err)
}
if !turnstileResp.Success {
return false, fmt.Errorf(
"cloudflare Turnstile verification failed: %v",
turnstileResp.ErrorCodes,
)
}
return true, nil
}

View File

@@ -0,0 +1,14 @@
package cloudflare_turnstile
import (
"databasus-backend/internal/config"
)
var cloudflareTurnstileService = &CloudflareTurnstileService{
config.GetEnv().CloudflareTurnstileSecretKey,
config.GetEnv().CloudflareTurnstileSiteKey,
}
func GetCloudflareTurnstileService() *CloudflareTurnstileService {
return cloudflareTurnstileService
}

View File

@@ -0,0 +1,17 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mongodb_databases ALTER COLUMN port DROP NOT NULL;
-- +goose StatementEnd
-- +goose StatementBegin
ALTER TABLE mongodb_databases ADD COLUMN is_srv BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mongodb_databases DROP COLUMN is_srv;
-- +goose StatementEnd
-- +goose StatementBegin
ALTER TABLE mongodb_databases ALTER COLUMN port SET NOT NULL;
-- +goose StatementEnd

View File

@@ -0,0 +1,11 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE mariadb_databases
ADD COLUMN IF NOT EXISTS is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE mariadb_databases
DROP COLUMN IF EXISTS is_exclude_events;
-- +goose StatementEnd

View File

@@ -2,4 +2,5 @@ MODE=development
VITE_GITHUB_CLIENT_ID=
VITE_GOOGLE_CLIENT_ID=
VITE_IS_EMAIL_CONFIGURED=false
VITE_IS_CLOUD=false
VITE_IS_CLOUD=false
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=

View File

@@ -3,6 +3,7 @@ interface RuntimeConfig {
GITHUB_CLIENT_ID?: string;
GOOGLE_CLIENT_ID?: string;
IS_EMAIL_CONFIGURED?: string;
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
}
declare global {
@@ -39,6 +40,11 @@ export const IS_EMAIL_CONFIGURED =
window.__RUNTIME_CONFIG__?.IS_EMAIL_CONFIGURED === 'true' ||
import.meta.env.VITE_IS_EMAIL_CONFIGURED === 'true';
export const CLOUDFLARE_TURNSTILE_SITE_KEY =
window.__RUNTIME_CONFIG__?.CLOUDFLARE_TURNSTILE_SITE_KEY ||
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
'';
export function getOAuthRedirectUri(): string {
return `${window.location.origin}/auth/callback`;
}

View File

@@ -32,6 +32,7 @@ describe('MongodbConnectionStringParser', () => {
expect(result.database).toBe('mydb');
expect(result.authDatabase).toBe('admin');
expect(result.useTls).toBe(false);
expect(result.isSrv).toBe(false);
});
it('should parse connection string without database', () => {
@@ -46,6 +47,7 @@ describe('MongodbConnectionStringParser', () => {
expect(result.database).toBe('');
expect(result.authDatabase).toBe('admin');
expect(result.useTls).toBe(false);
expect(result.isSrv).toBe(false);
});
it('should default port to 27017 when not specified', () => {
@@ -107,6 +109,7 @@ describe('MongodbConnectionStringParser', () => {
expect(result.password).toBe('atlaspass');
expect(result.database).toBe('mydb');
expect(result.useTls).toBe(true); // SRV connections use TLS by default
expect(result.isSrv).toBe(true);
});
it('should parse mongodb+srv:// without database', () => {
@@ -119,6 +122,7 @@ describe('MongodbConnectionStringParser', () => {
expect(result.host).toBe('cluster0.abc123.mongodb.net');
expect(result.database).toBe('');
expect(result.useTls).toBe(true);
expect(result.isSrv).toBe(true);
});
});
@@ -314,13 +318,15 @@ describe('MongodbConnectionStringParser', () => {
expect(result.format).toBe('key-value');
});
it('should return error for key-value format missing password', () => {
const result = expectError(
it('should allow missing password in key-value format (returns empty password)', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse('host=localhost database=mydb user=admin'),
);
expect(result.error).toContain('Password');
expect(result.format).toBe('key-value');
expect(result.host).toBe('localhost');
expect(result.username).toBe('admin');
expect(result.password).toBe('');
expect(result.database).toBe('mydb');
});
});
@@ -351,12 +357,15 @@ describe('MongodbConnectionStringParser', () => {
expect(result.error).toContain('Username');
});
it('should return error for missing password in URI', () => {
const result = expectError(
it('should allow missing password in URI (returns empty password)', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse('mongodb://user@host:27017/db'),
);
expect(result.error).toContain('Password');
expect(result.username).toBe('user');
expect(result.password).toBe('');
expect(result.host).toBe('host');
expect(result.database).toBe('db');
});
it('should return error for mysql:// format (wrong database type)', () => {
@@ -446,4 +455,67 @@ describe('MongodbConnectionStringParser', () => {
expect(result.database).toBe('');
});
});
describe('Password Placeholder Handling', () => {
it('should treat <db_password> placeholder as empty password in URI format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse('mongodb://user:<db_password>@host:27017/db'),
);
expect(result.username).toBe('user');
expect(result.password).toBe('');
expect(result.host).toBe('host');
expect(result.database).toBe('db');
});
it('should treat <password> placeholder as empty password in URI format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse('mongodb://user:<password>@host:27017/db'),
);
expect(result.username).toBe('user');
expect(result.password).toBe('');
expect(result.host).toBe('host');
expect(result.database).toBe('db');
});
it('should treat <db_password> placeholder as empty password in SRV format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'mongodb+srv://user:<db_password>@cluster0.mongodb.net/db',
),
);
expect(result.username).toBe('user');
expect(result.password).toBe('');
expect(result.host).toBe('cluster0.mongodb.net');
expect(result.isSrv).toBe(true);
});
it('should treat <db_password> placeholder as empty password in key-value format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'host=localhost database=mydb user=admin password=<db_password>',
),
);
expect(result.host).toBe('localhost');
expect(result.username).toBe('admin');
expect(result.password).toBe('');
expect(result.database).toBe('mydb');
});
it('should treat <password> placeholder as empty password in key-value format', () => {
const result = expectSuccess(
MongodbConnectionStringParser.parse(
'host=localhost database=mydb user=admin password=<password>',
),
);
expect(result.host).toBe('localhost');
expect(result.username).toBe('admin');
expect(result.password).toBe('');
expect(result.database).toBe('mydb');
});
});
});

View File

@@ -6,6 +6,7 @@ export type ParseResult = {
database: string;
authDatabase: string;
useTls: boolean;
isSrv: boolean;
};
export type ParseError = {
@@ -63,7 +64,8 @@ export class MongodbConnectionStringParser {
const host = url.hostname;
const port = url.port ? parseInt(url.port, 10) : isSrv ? 27017 : 27017;
const username = decodeURIComponent(url.username);
const password = decodeURIComponent(url.password);
const rawPassword = decodeURIComponent(url.password);
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword;
const database = decodeURIComponent(url.pathname.slice(1));
const authDatabase = this.getAuthSource(url.search) || 'admin';
const useTls = isSrv ? true : this.checkTlsMode(url.search);
@@ -76,10 +78,6 @@ export class MongodbConnectionStringParser {
return { error: 'Username is missing from connection string' };
}
if (!password) {
return { error: 'Password is missing from connection string' };
}
return {
host,
port,
@@ -88,6 +86,7 @@ export class MongodbConnectionStringParser {
database: database || '',
authDatabase,
useTls,
isSrv,
};
} catch (e) {
return {
@@ -114,7 +113,8 @@ export class MongodbConnectionStringParser {
const port = params['port'];
const database = params['database'] || params['dbname'] || params['db'];
const username = params['user'] || params['username'];
const password = params['password'];
const rawPassword = params['password'];
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword || '';
const authDatabase = params['authSource'] || params['authDatabase'] || 'admin';
const tls = params['tls'] || params['ssl'];
@@ -132,13 +132,6 @@ export class MongodbConnectionStringParser {
};
}
if (!password) {
return {
error: 'Password is missing from connection string. Use password=yourpassword',
format: 'key-value',
};
}
const useTls = this.isTlsEnabled(tls);
return {
@@ -149,6 +142,7 @@ export class MongodbConnectionStringParser {
database: database || '',
authDatabase,
useTls,
isSrv: false,
};
} catch (e) {
return {
@@ -191,4 +185,11 @@ export class MongodbConnectionStringParser {
const enabledValues = ['true', 'yes', '1'];
return enabledValues.includes(lowercased);
}
private static isPasswordPlaceholder(password: string | null | undefined): boolean {
if (!password) return false;
const trimmed = password.trim();
return trimmed === '<db_password>' || trimmed === '<password>';
}
}

View File

@@ -10,5 +10,6 @@ export interface MongodbDatabase {
database: string;
authDatabase: string;
isHttps: boolean;
isSrv: boolean;
cpuCount: number;
}

View File

@@ -31,10 +31,18 @@ const notifyAuthListeners = () => {
};
export const userApi = {
async signUp(signUpRequest: SignUpRequest) {
async signUp(signUpRequest: SignUpRequest): Promise<SignInResponse> {
const requestOptions: RequestOptions = new RequestOptions();
requestOptions.setBody(JSON.stringify(signUpRequest));
return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/users/signup`, requestOptions);
return apiHelper
.fetchPostJson(`${getApplicationServer()}/api/v1/users/signup`, requestOptions)
.then((response: unknown): SignInResponse => {
const typedResponse = response as SignInResponse;
saveAuthorizedData(typedResponse.token, typedResponse.userId);
notifyAuthListeners();
return typedResponse;
});
},
async signIn(signInRequest: SignInRequest): Promise<SignInResponse> {

View File

@@ -1,3 +1,4 @@
export interface SendResetPasswordCodeRequest {
email: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -1,4 +1,5 @@
export interface SignInRequest {
email: string;
password: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -2,4 +2,5 @@ export interface SignUpRequest {
email: string;
password: string;
name: string;
cloudflareTurnstileToken?: string;
}

View File

@@ -46,7 +46,7 @@ export const EditMongoDbSpecificDataComponent = ({
const [isTestingConnection, setIsTestingConnection] = useState(false);
const [isConnectionFailed, setIsConnectionFailed] = useState(false);
const hasAdvancedValues = !!database.mongodb?.authDatabase;
const hasAdvancedValues = !!database.mongodb?.authDatabase || !!database.mongodb?.isSrv;
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
const parseFromClipboard = async () => {
@@ -75,17 +75,29 @@ export const EditMongoDbSpecificDataComponent = ({
host: result.host,
port: result.port,
username: result.username,
password: result.password,
password: result.password || '',
database: result.database,
authDatabase: result.authDatabase,
isHttps: result.useTls,
isSrv: result.isSrv,
cpuCount: 1,
},
};
if (result.isSrv) {
setShowAdvanced(true);
}
setEditingDatabase(updatedDatabase);
setIsConnectionTested(false);
message.success('Connection string parsed successfully');
if (!result.password) {
message.warning(
'Connection string parsed successfully. Please enter the password manually.',
);
} else {
message.success('Connection string parsed successfully');
}
} catch {
message.error('Failed to read clipboard. Please check browser permissions.');
}
@@ -156,9 +168,11 @@ export const EditMongoDbSpecificDataComponent = ({
if (!editingDatabase) return null;
const isSrvConnection = editingDatabase.mongodb?.isSrv || false;
let isAllFieldsFilled = true;
if (!editingDatabase.mongodb?.host) isAllFieldsFilled = false;
if (!editingDatabase.mongodb?.port) isAllFieldsFilled = false;
if (!isSrvConnection && !editingDatabase.mongodb?.port) isAllFieldsFilled = false;
if (!editingDatabase.mongodb?.username) isAllFieldsFilled = false;
if (!editingDatabase.id && !editingDatabase.mongodb?.password) isAllFieldsFilled = false;
if (!editingDatabase.mongodb?.database) isAllFieldsFilled = false;
@@ -220,25 +234,27 @@ export const EditMongoDbSpecificDataComponent = ({
</div>
)}
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Port</div>
<InputNumber
type="number"
value={editingDatabase.mongodb?.port}
onChange={(e) => {
if (!editingDatabase.mongodb || e === null) return;
{!isSrvConnection && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Port</div>
<InputNumber
type="number"
value={editingDatabase.mongodb?.port}
onChange={(e) => {
if (!editingDatabase.mongodb || e === null) return;
setEditingDatabase({
...editingDatabase,
mongodb: { ...editingDatabase.mongodb, port: e },
});
setIsConnectionTested(false);
}}
size="small"
className="max-w-[200px] grow"
placeholder="27017"
/>
</div>
setEditingDatabase({
...editingDatabase,
mongodb: { ...editingDatabase.mongodb, port: e },
});
setIsConnectionTested(false);
}}
size="small"
className="max-w-[200px] grow"
placeholder="27017"
/>
</div>
)}
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Username</div>
@@ -366,6 +382,31 @@ export const EditMongoDbSpecificDataComponent = ({
{isShowAdvanced && (
<>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Use SRV connection</div>
<div className="flex items-center">
<Switch
checked={editingDatabase.mongodb?.isSrv || false}
onChange={(checked) => {
if (!editingDatabase.mongodb) return;
setEditingDatabase({
...editingDatabase,
mongodb: { ...editingDatabase.mongodb, isSrv: checked },
});
setIsConnectionTested(false);
}}
size="small"
/>
<Tooltip
className="cursor-pointer"
title="Enable for MongoDB Atlas SRV connections (mongodb+srv://). Port is not required for SRV connections."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Auth database</div>
<Input

View File

@@ -1,9 +1,12 @@
import { Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
interface RequestResetPasswordComponentProps {
onSwitchToSignIn?: () => void;
@@ -20,6 +23,8 @@ export function RequestResetPasswordComponent({
const [error, setError] = useState('');
const [successMessage, setSuccessMessage] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateEmail = (): boolean => {
if (!email) {
setEmailError(true);
@@ -42,7 +47,10 @@ export function RequestResetPasswordComponent({
setLoading(true);
try {
const response = await userApi.sendResetPasswordCode({ email });
const response = await userApi.sendResetPasswordCode({
email,
cloudflareTurnstileToken: token,
});
setSuccessMessage(response.message);
// After successful code send, switch to reset password form
@@ -53,6 +61,7 @@ export function RequestResetPasswordComponent({
}, 2000);
} catch (e) {
setError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
setLoading(false);
@@ -84,6 +93,8 @@ export function RequestResetPasswordComponent({
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
import { Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID, IS_EMAIL_CONFIGURED } from '../../../constants';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
@@ -29,6 +32,8 @@ export function SignInComponent({
const [signInError, setSignInError] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateFieldsForSignIn = (): boolean => {
if (!email) {
setEmailError(true);
@@ -59,9 +64,11 @@ export function SignInComponent({
await userApi.signIn({
email,
password,
cloudflareTurnstileToken: token,
});
} catch (e) {
setSignInError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
setLoading(false);
@@ -119,6 +126,8 @@ export function SignInComponent({
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
import { App, Button, Input } from 'antd';
import { type JSX, useState } from 'react';
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
import { userApi } from '../../../entity/users';
import { StringUtils } from '../../../shared/lib';
import { FormValidator } from '../../../shared/lib/FormValidator';
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
@@ -31,6 +34,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
const [signUpError, setSignUpError] = useState('');
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
const validateFieldsForSignUp = (): boolean => {
if (!name || name.trim() === '') {
setNameError(true);
@@ -85,10 +90,11 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
email,
password,
name,
cloudflareTurnstileToken: token,
});
await userApi.signIn({ email, password });
} catch (e) {
setSignUpError(StringUtils.capitalizeFirstLetter((e as Error).message));
resetCloudflareTurnstile();
}
}
@@ -173,6 +179,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
<div className="mt-3" />
<CloudflareTurnstileWidget containerRef={containerRef} />
<Button
disabled={isLoading}
loading={isLoading}

View File

@@ -1,7 +1,6 @@
import dayjs from 'dayjs';
import relativeTime from 'dayjs/plugin/relativeTime';
import utc from 'dayjs/plugin/utc';
import { StrictMode } from 'react';
import { createRoot } from 'react-dom/client';
import './index.css';
@@ -11,8 +10,4 @@ import App from './App.tsx';
dayjs.extend(utc);
dayjs.extend(relativeTime);
createRoot(document.getElementById('root')!).render(
<StrictMode>
<App />
</StrictMode>,
);
createRoot(document.getElementById('root')!).render(<App />);

View File

@@ -0,0 +1,116 @@
import { useEffect, useRef, useState } from 'react';
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
declare global {
interface Window {
turnstile?: {
render: (
container: string | HTMLElement,
options: {
sitekey: string;
callback: (token: string) => void;
'error-callback'?: () => void;
'expired-callback'?: () => void;
theme?: 'light' | 'dark' | 'auto';
size?: 'normal' | 'compact' | 'flexible';
appearance?: 'always' | 'execute' | 'interaction-only';
},
) => string;
reset: (widgetId: string) => void;
remove: (widgetId: string) => void;
getResponse: (widgetId: string) => string | undefined;
};
}
}
interface UseCloudflareTurnstileReturn {
containerRef: React.RefObject<HTMLDivElement | null>;
token: string | undefined;
resetCloudflareTurnstile: () => void;
}
const loadCloudflareTurnstileScript = (): Promise<void> => {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
return Promise.resolve();
}
return new Promise((resolve, reject) => {
if (document.querySelector('script[src*="turnstile"]')) {
resolve();
return;
}
const script = document.createElement('script');
script.src = 'https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit';
script.async = true;
script.defer = true;
script.onload = () => resolve();
script.onerror = () => reject(new Error('Failed to load Cloudflare Turnstile'));
document.head.appendChild(script);
});
};
export function useCloudflareTurnstile(): UseCloudflareTurnstileReturn {
const [token, setToken] = useState<string | undefined>(undefined);
const containerRef = useRef<HTMLDivElement | null>(null);
const widgetIdRef = useRef<string | null>(null);
useEffect(() => {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY || !containerRef.current) {
return;
}
loadCloudflareTurnstileScript()
.then(() => {
if (!window.turnstile || !containerRef.current) {
return;
}
try {
const widgetId = window.turnstile.render(containerRef.current, {
sitekey: CLOUDFLARE_TURNSTILE_SITE_KEY,
callback: (receivedToken: string) => {
setToken(receivedToken);
},
'error-callback': () => {
setToken(undefined);
},
'expired-callback': () => {
setToken(undefined);
},
theme: 'auto',
size: 'normal',
appearance: 'execute',
});
widgetIdRef.current = widgetId;
} catch (error) {
console.error('Failed to render Cloudflare Turnstile widget:', error);
}
})
.catch((error) => {
console.error('Failed to load Cloudflare Turnstile:', error);
});
return () => {
if (widgetIdRef.current && window.turnstile) {
window.turnstile.remove(widgetIdRef.current);
widgetIdRef.current = null;
}
};
}, []);
const resetCloudflareTurnstile = () => {
if (widgetIdRef.current && window.turnstile) {
window.turnstile.reset(widgetIdRef.current);
setToken(undefined);
}
};
return {
containerRef,
token,
resetCloudflareTurnstile,
};
}

View File

@@ -0,0 +1,17 @@
import { type JSX } from 'react';
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
interface CloudflareTurnstileWidgetProps {
containerRef: React.RefObject<HTMLDivElement | null>;
}
export function CloudflareTurnstileWidget({
containerRef,
}: CloudflareTurnstileWidgetProps): JSX.Element | null {
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
return null;
}
return <div ref={containerRef} className="mb-3" />;
}

View File

@@ -1,3 +1,4 @@
export { CloudflareTurnstileWidget } from './CloudflareTurnstileWidget';
export { ConfirmationComponent } from './ConfirmationComponent';
export { StarButtonComponent } from './StarButtonComponent';
export { ThemeToggleComponent } from './ThemeToggleComponent';