mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2553203fcf | ||
|
|
7b05bd8000 | ||
|
|
8d45728f73 | ||
|
|
c70ad82c95 | ||
|
|
e4bc34d319 | ||
|
|
257ae85da7 | ||
|
|
b42c820bb2 | ||
|
|
da5c13fb11 | ||
|
|
35180360e5 | ||
|
|
e4f6cd7a5d | ||
|
|
d7b8e6d56a | ||
|
|
6016f23fb2 |
@@ -268,7 +268,8 @@ window.__RUNTIME_CONFIG__ = {
|
||||
IS_CLOUD: '\${IS_CLOUD:-false}',
|
||||
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED'
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
[](https://www.mongodb.com/)
|
||||
<br />
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://hub.docker.com/r/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
@@ -31,8 +31,6 @@
|
||||
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
|
||||
|
||||
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@@ -11,6 +11,9 @@ VICTORIA_LOGS_PASSWORD=devpassword
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# cloudflare turnstile
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
CLOUDFLARE_TURNSTILE_SECRET_KEY=
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
|
||||
@@ -104,6 +104,10 @@ type EnvVariables struct {
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// Cloudflare Turnstile
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
|
||||
@@ -1366,11 +1366,24 @@ func createTestBackup(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
loadedStorages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(loadedStorages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
// Filter out system storages
|
||||
var nonSystemStorages []*storages.Storage
|
||||
for _, storage := range loadedStorages {
|
||||
if !storage.IsSystem {
|
||||
nonSystemStorages = append(nonSystemStorages, storage)
|
||||
}
|
||||
}
|
||||
if len(nonSystemStorages) == 0 {
|
||||
panic("No non-system storage found for workspace")
|
||||
}
|
||||
|
||||
storages := nonSystemStorages
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
|
||||
@@ -108,13 +108,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if mdb.HasPrivilege("TRIGGER") {
|
||||
args = append(args, "--triggers")
|
||||
}
|
||||
if mdb.HasPrivilege("EVENT") {
|
||||
|
||||
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
|
||||
args = append(args, "--events")
|
||||
}
|
||||
|
||||
|
||||
@@ -107,6 +107,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
|
||||
"--routines",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--skip-extended-insert",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
|
||||
@@ -1164,12 +1164,13 @@ func getTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,13 +25,14 @@ type MariadbDatabase struct {
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
|
||||
Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsExcludeEvents = incoming.IsExcludeEvents
|
||||
m.Privileges = incoming.Privileges
|
||||
|
||||
if incoming.Password != "" {
|
||||
|
||||
@@ -26,12 +26,13 @@ type MongodbDatabase struct {
|
||||
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Port *int `json:"port" gorm:"type:int"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database string `json:"database" gorm:"type:text;not null"`
|
||||
AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
|
||||
CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
|
||||
}
|
||||
|
||||
@@ -43,9 +44,13 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return errors.New("port is required")
|
||||
|
||||
if !m.IsSrv {
|
||||
if m.Port == nil || *m.Port == 0 {
|
||||
return errors.New("port is required for standard connections")
|
||||
}
|
||||
}
|
||||
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
@@ -58,6 +63,7 @@ func (m *MongodbDatabase) Validate() error {
|
||||
if m.CpuCount <= 0 {
|
||||
return errors.New("cpu count must be greater than 0")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -125,6 +131,7 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
|
||||
m.Database = incoming.Database
|
||||
m.AuthDatabase = incoming.AuthDatabase
|
||||
m.IsHttps = incoming.IsHttps
|
||||
m.IsSrv = incoming.IsSrv
|
||||
m.CpuCount = incoming.CpuCount
|
||||
|
||||
if incoming.Password != "" {
|
||||
@@ -455,12 +462,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
m.Database,
|
||||
authDB,
|
||||
tlsParams,
|
||||
@@ -479,12 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
tlsParams = "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
m.Port,
|
||||
port,
|
||||
authDB,
|
||||
tlsParams,
|
||||
)
|
||||
|
||||
@@ -64,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: limitedUsername,
|
||||
Password: limitedPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -133,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
|
||||
|
||||
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
|
||||
|
||||
port := container.Port
|
||||
mongodbModel := &MongodbDatabase{
|
||||
Version: tc.version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: backupUsername,
|
||||
Password: backupPassword,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
@@ -442,15 +446,17 @@ func connectToMongodbContainer(
|
||||
}
|
||||
|
||||
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
|
||||
port := container.Port
|
||||
return &MongodbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Port: &port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: container.Database,
|
||||
AuthDatabase: container.AuthDatabase,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
@@ -489,3 +495,157 @@ func assertWriteDenied(t *testing.T, err error) {
|
||||
strings.Contains(errStr, "permission denied"),
|
||||
"Expected authorization error, got: %v", err)
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/mydb")
|
||||
assert.Contains(t, uri, "authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
}
|
||||
|
||||
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.buildConnectionURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb+srv://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "cluster0.example.mongodb.net")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, ":27017")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
|
||||
port := 27017
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: &port,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
}
|
||||
|
||||
uri := model.BuildMongodumpURI("testpass123")
|
||||
|
||||
assert.Contains(t, uri, "mongodb://")
|
||||
assert.Contains(t, uri, "testuser")
|
||||
assert.Contains(t, uri, "testpass123")
|
||||
assert.Contains(t, uri, "localhost:27017")
|
||||
assert.Contains(t, uri, "/?authSource=admin")
|
||||
assert.Contains(t, uri, "connectTimeoutMS=15000")
|
||||
assert.NotContains(t, uri, "mongodb+srv://")
|
||||
assert.NotContains(t, uri, "/mydb")
|
||||
}
|
||||
|
||||
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "cluster0.example.mongodb.net",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: true,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
|
||||
model := &MongodbDatabase{
|
||||
Host: "localhost",
|
||||
Port: nil,
|
||||
Username: "testuser",
|
||||
Password: "testpass123",
|
||||
Database: "mydb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
|
||||
err := model.Validate()
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "port is required for standard connections")
|
||||
}
|
||||
|
||||
@@ -564,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 4: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
// Step 4: Discover schemas to grant privileges on
|
||||
// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
|
||||
var rows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
AND schema_name = ANY($1::text[])
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
rows, err = tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get schemas: %w", err)
|
||||
}
|
||||
@@ -619,50 +630,197 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
|
||||
// Step 6: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, grantSelectSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to grant select on sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 7: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, baseUsername, baseUsername)
|
||||
// First, set default privileges for objects created by the current user
|
||||
// Use the already-filtered schemas list from Step 4
|
||||
for _, schema := range schemas {
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for tables in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
schema,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf(
|
||||
"failed to set default privileges for sequences in schema %s: %w",
|
||||
schema,
|
||||
err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 8: Verify user creation before committing
|
||||
// Step 8: Discover all roles that own objects in each schema
|
||||
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
|
||||
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
|
||||
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
|
||||
// Filter by IncludeSchemas if specified.
|
||||
type SchemaOwner struct {
|
||||
SchemaName string
|
||||
RoleName string
|
||||
}
|
||||
|
||||
var ownerRows pgx.Rows
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND n.nspname = ANY($1::text[])
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`, p.IncludeSchemas)
|
||||
} else {
|
||||
ownerRows, err = tx.Query(ctx, `
|
||||
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
|
||||
FROM pg_class c
|
||||
JOIN pg_namespace n ON c.relnamespace = n.oid
|
||||
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
|
||||
AND pg_get_userbyid(c.relowner) != current_user
|
||||
ORDER BY n.nspname, role_name
|
||||
`)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// Log warning but continue - this is a best-effort enhancement
|
||||
logger.Warn("Failed to query object owners for default privileges", "error", err)
|
||||
} else {
|
||||
var schemaOwners []SchemaOwner
|
||||
for ownerRows.Next() {
|
||||
var so SchemaOwner
|
||||
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
|
||||
ownerRows.Close()
|
||||
logger.Warn("Failed to scan schema owner", "error", err)
|
||||
break
|
||||
}
|
||||
schemaOwners = append(schemaOwners, so)
|
||||
}
|
||||
ownerRows.Close()
|
||||
|
||||
if err := ownerRows.Err(); err != nil {
|
||||
logger.Warn("Error iterating schema owners", "error", err)
|
||||
}
|
||||
|
||||
// Step 9: Set default privileges FOR ROLE for each object owner
|
||||
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
|
||||
// We log warnings but continue - user creation should succeed even if some roles can't be configured
|
||||
for _, so := range schemaOwners {
|
||||
// Try to set default privileges for tables
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (tables)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
|
||||
// Try to set default privileges for sequences
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(
|
||||
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
|
||||
so.RoleName,
|
||||
so.SchemaName,
|
||||
baseUsername,
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to set default privileges for role (sequences)",
|
||||
"error",
|
||||
err,
|
||||
"role",
|
||||
so.RoleName,
|
||||
"schema",
|
||||
so.SchemaName,
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if len(schemaOwners) > 0 {
|
||||
logger.Info(
|
||||
"Set default privileges for existing object owners",
|
||||
"readonly_user",
|
||||
baseUsername,
|
||||
"owner_count",
|
||||
len(schemaOwners),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 10: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
|
||||
@@ -1319,6 +1319,346 @@ type PostgresContainer struct {
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create a second database user who will create tables
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Step 2: Grant the user_creator privileges to connect and create tables
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CREATE ON SCHEMA public TO "%s"`,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 2b: Create an initial table by user_creator so they become an object owner
|
||||
// This is important because our fix discovers existing object owners
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
initialTableName := fmt.Sprintf(
|
||||
"public.initial_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('initial_data');
|
||||
`, initialTableName, initialTableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
|
||||
}()
|
||||
|
||||
// Step 3: NOW create read-only user via Databasus (as admin)
|
||||
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
|
||||
// This table should automatically grant SELECT to the read-only user via ALTER DEFAULT PRIVILEGES FOR ROLE
|
||||
tableName := fmt.Sprintf(
|
||||
"public.future_table_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = userCreatorConn.Exec(fmt.Sprintf(`
|
||||
CREATE TABLE %s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
|
||||
`, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
// Step 5: Connect as read-only user and verify it can SELECT from the new table
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
2,
|
||||
count,
|
||||
"Read-only user should be able to SELECT from table created by different user",
|
||||
)
|
||||
|
||||
// Step 6: Verify read-only user cannot write to the table
|
||||
_, err = readonlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify pg_dump operations (LOCK TABLE) work
|
||||
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
|
||||
tx, err := readonlyConn.Begin()
|
||||
assert.NoError(t, err)
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
|
||||
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
|
||||
|
||||
err = tx.Commit()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
// Step 1: Create multiple schemas and tables
|
||||
_, err := container.DB.Exec(`
|
||||
DROP SCHEMA IF EXISTS included_schema CASCADE;
|
||||
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
|
||||
CREATE SCHEMA included_schema;
|
||||
CREATE SCHEMA excluded_schema;
|
||||
|
||||
CREATE TABLE public.public_table (id INT, data TEXT);
|
||||
INSERT INTO public.public_table VALUES (1, 'public_data');
|
||||
|
||||
CREATE TABLE included_schema.included_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
|
||||
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
|
||||
}()
|
||||
|
||||
// Step 2: Create a second user who owns tables in both included and excluded schemas
|
||||
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
|
||||
userCreatorPassword := "creator_password_123"
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
|
||||
}()
|
||||
|
||||
// Grant privileges to user_creator
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
|
||||
container.Database,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
|
||||
_, err = container.DB.Exec(fmt.Sprintf(
|
||||
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
|
||||
schema,
|
||||
userCreatorUsername,
|
||||
))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// User_creator creates tables in included and excluded schemas
|
||||
userCreatorDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
userCreatorUsername,
|
||||
userCreatorPassword,
|
||||
container.Database,
|
||||
)
|
||||
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
|
||||
assert.NoError(t, err)
|
||||
defer userCreatorConn.Close()
|
||||
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.user_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
|
||||
|
||||
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
|
||||
pgModel := createPostgresModel(container)
|
||||
pgModel.IncludeSchemas = []string{"public", "included_schema"}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
|
||||
ctx,
|
||||
logger,
|
||||
nil,
|
||||
uuid.New(),
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, readonlyUsername)
|
||||
assert.NotEmpty(t, readonlyPassword)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
|
||||
}()
|
||||
|
||||
// Step 4: Connect as read-only user
|
||||
readonlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
readonlyUsername,
|
||||
readonlyPassword,
|
||||
container.Database,
|
||||
)
|
||||
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readonlyConn.Close()
|
||||
|
||||
// Step 5: Verify read-only user CAN access included schemas
|
||||
var publicData string
|
||||
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "public_data", publicData)
|
||||
|
||||
var includedData string
|
||||
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "included_data", includedData)
|
||||
|
||||
var userIncludedData string
|
||||
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "user_included_data", userIncludedData)
|
||||
|
||||
// Step 6: Verify read-only user CANNOT access excluded schema
|
||||
var excludedData string
|
||||
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
err = readonlyConn.Get(
|
||||
&excludedData,
|
||||
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Step 7: Verify future tables in included schemas are accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE included_schema.future_table (id INT, data TEXT);
|
||||
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureData string
|
||||
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"future_data",
|
||||
futureData,
|
||||
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
|
||||
)
|
||||
|
||||
// Step 8: Verify future tables in excluded schema are NOT accessible
|
||||
_, err = userCreatorConn.Exec(`
|
||||
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
|
||||
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var futureExcludedData string
|
||||
err = readonlyConn.Get(
|
||||
&futureExcludedData,
|
||||
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(
|
||||
t,
|
||||
err.Error(),
|
||||
"permission denied",
|
||||
"Read-only user should NOT access tables in excluded schemas",
|
||||
)
|
||||
}
|
||||
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
|
||||
@@ -71,12 +71,13 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
|
||||
return &mongodb.MongodbDatabase{
|
||||
Version: tools.MongodbVersion7,
|
||||
Host: config.GetEnv().TestLocalhost,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: "root",
|
||||
Password: "rootpassword",
|
||||
Database: "testdb",
|
||||
AuthDatabase: "admin",
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ import (
|
||||
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
users_dto "databasus-backend/internal/features/users/dto"
|
||||
users_enums "databasus-backend/internal/features/users/enums"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
users_testing "databasus-backend/internal/features/users/testing"
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
@@ -358,7 +357,7 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup = createTestBackup(mysqlDB, owner)
|
||||
backup = createTestBackup(mysqlDB, storage)
|
||||
|
||||
request = restores_core.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysql.MysqlDatabase{
|
||||
@@ -610,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
backup := createTestBackup(database, storage)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
@@ -727,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
storage *storages.Storage,
|
||||
) *backups_core.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
StorageID: storage.ID,
|
||||
Status: backups_core.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
@@ -759,7 +748,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(
|
||||
if err := storage.SaveFile(
|
||||
context.Background(),
|
||||
fieldEncryptor,
|
||||
logger,
|
||||
|
||||
@@ -147,6 +147,26 @@ func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testin
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMariadb_WithExcludeEvents_EventsNotRestored(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMariadbBackupRestoreWithExcludeEventsForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testMariadbBackupRestoreForVersion(
|
||||
t *testing.T,
|
||||
mariadbVersion tools.MariadbVersion,
|
||||
@@ -702,3 +722,145 @@ func updateMariadbDatabaseCredentialsViaAPI(
|
||||
|
||||
return &updatedDatabase
|
||||
}
|
||||
|
||||
func testMariadbBackupRestoreWithExcludeEventsForVersion(
|
||||
t *testing.T,
|
||||
mariadbVersion tools.MariadbVersion,
|
||||
port string,
|
||||
) {
|
||||
container, err := connectToMariadbContainer(mariadbVersion, port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMariadbTestData(t, container.DB)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE EVENT IF NOT EXISTS test_event
|
||||
ON SCHEDULE EVERY 1 DAY
|
||||
DO BEGIN
|
||||
INSERT INTO test_data (name, value) VALUES ('event_test', 999);
|
||||
END
|
||||
`)
|
||||
if err != nil {
|
||||
t.Skipf(
|
||||
"Skipping test: MariaDB version doesn't support events or event scheduler disabled: %v",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"MariaDB Exclude Events Test Workspace",
|
||||
user,
|
||||
router,
|
||||
)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMariadbDatabaseViaAPI(
|
||||
t, router, "MariaDB Exclude Events Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
database.Mariadb.IsExcludeEvents = true
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+user.Token,
|
||||
database,
|
||||
)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf(
|
||||
"Failed to update database with IsExcludeEvents. Status: %d, Body: %s",
|
||||
w.Code,
|
||||
w.Body.String(),
|
||||
)
|
||||
}
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, database.ID, storage.ID,
|
||||
backups_config.BackupEncryptionNone, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups_core.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mariadb_no_events"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMariadbRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_core.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMariadbDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
var eventCount int
|
||||
err = newDB.Get(
|
||||
&eventCount,
|
||||
"SELECT COUNT(*) FROM information_schema.events WHERE event_schema = ? AND event_name = 'test_event'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
0,
|
||||
eventCount,
|
||||
"Event 'test_event' should NOT exist in restored database when IsExcludeEvents is true",
|
||||
)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+database.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
@@ -385,13 +385,14 @@ func createMongodbDatabaseViaAPI(
|
||||
Type: databases.DatabaseTypeMongodb,
|
||||
Mongodb: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthDatabase: authDatabase,
|
||||
Version: version,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
@@ -432,13 +433,14 @@ func createMongodbRestoreViaAPI(
|
||||
request := restores_core.RestoreBackupRequest{
|
||||
MongodbDatabase: &mongodbtypes.MongodbDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Port: &port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: database,
|
||||
AuthDatabase: authDatabase,
|
||||
Version: version,
|
||||
IsHttps: false,
|
||||
IsSrv: false,
|
||||
CpuCount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -726,41 +726,28 @@ func Test_InviteUserToWorkspace_MembershipReceivedAfterSignUp(t *testing.T) {
|
||||
|
||||
assert.Equal(t, workspaces_dto.AddStatusInvited, inviteResponse.Status)
|
||||
|
||||
// 3. Sign up the invited user
|
||||
// 3. Sign up the invited user (now returns token directly)
|
||||
signUpRequest := users_dto.SignUpRequestDTO{
|
||||
Email: inviteEmail,
|
||||
Password: "testpassword123",
|
||||
Name: "Invited User",
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
var signInResponse users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signup",
|
||||
"",
|
||||
signUpRequest,
|
||||
http.StatusOK,
|
||||
)
|
||||
assert.Contains(t, string(resp.Body), "User created successfully")
|
||||
|
||||
// 4. Sign in the newly registered user
|
||||
signInRequest := users_dto.SignInRequestDTO{
|
||||
Email: inviteEmail,
|
||||
Password: "testpassword123",
|
||||
}
|
||||
|
||||
var signInResponse users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signInRequest,
|
||||
http.StatusOK,
|
||||
&signInResponse,
|
||||
)
|
||||
|
||||
// 5. Verify user is automatically added as member to workspace
|
||||
assert.NotEmpty(t, signInResponse.Token)
|
||||
assert.Equal(t, inviteEmail, signInResponse.Email)
|
||||
|
||||
// 4. Verify user is automatically added as member to workspace
|
||||
var membersResponse workspaces_dto.GetMembersResponseDTO
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
user_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
cloudflare_turnstile "databasus-backend/internal/util/cloudflare_turnstile"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -51,7 +52,7 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body users_dto.SignUpRequestDTO true "User signup data"
|
||||
// @Success 200
|
||||
// @Success 200 {object} users_dto.SignInResponseDTO
|
||||
// @Failure 400
|
||||
// @Router /users/signup [post]
|
||||
func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
@@ -61,13 +62,41 @@ func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.userService.SignUp(&request)
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err := c.userService.SignUp(&request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
|
||||
response, err := c.userService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// SignIn
|
||||
@@ -88,6 +117,28 @@ func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
|
||||
if !allowed {
|
||||
ctx.JSON(
|
||||
@@ -363,6 +414,28 @@ func (c *UserController) SendResetPasswordCode(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(
|
||||
request.Email,
|
||||
"reset-password",
|
||||
|
||||
@@ -27,7 +27,20 @@ func Test_SignUpUser_WithValidData_UserCreated(t *testing.T) {
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", request, http.StatusOK)
|
||||
var response users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signup",
|
||||
"",
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, response.Token)
|
||||
assert.NotEqual(t, uuid.Nil, response.UserID)
|
||||
assert.Equal(t, request.Email, response.Email)
|
||||
}
|
||||
|
||||
func Test_SignUpUser_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {
|
||||
|
||||
@@ -9,14 +9,16 @@ import (
|
||||
)
|
||||
|
||||
type SignUpRequestDTO struct {
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type SignInRequestDTO struct {
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type SignInResponseDTO struct {
|
||||
@@ -94,7 +96,8 @@ type OAuthCallbackResponseDTO struct {
|
||||
}
|
||||
|
||||
type SendResetPasswordCodeRequestDTO struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type ResetPasswordRequestDTO struct {
|
||||
|
||||
@@ -44,19 +44,19 @@ func (s *UserService) SetEmailSender(sender users_interfaces.EmailSender) {
|
||||
s.emailSender = sender
|
||||
}
|
||||
|
||||
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) (*users_models.User, error) {
|
||||
existingUser, err := s.userRepository.GetUserByEmail(request.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing user: %w", err)
|
||||
return nil, fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
|
||||
if existingUser != nil && existingUser.Status != users_enums.UserStatusInvited {
|
||||
return errors.New("user with this email already exists")
|
||||
return nil, errors.New("user with this email already exists")
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
hashedPasswordStr := string(hashedPassword)
|
||||
@@ -67,39 +67,45 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
existingUser.ID,
|
||||
hashedPasswordStr,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to set password: %w", err)
|
||||
return nil, fmt.Errorf("failed to set password: %w", err)
|
||||
}
|
||||
|
||||
if err := s.userRepository.UpdateUserStatus(
|
||||
existingUser.ID,
|
||||
users_enums.UserStatusActive,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to activate user: %w", err)
|
||||
return nil, fmt.Errorf("failed to activate user: %w", err)
|
||||
}
|
||||
|
||||
name := request.Name
|
||||
if err := s.userRepository.UpdateUserInfo(existingUser.ID, &name, nil); err != nil {
|
||||
return fmt.Errorf("failed to update name: %w", err)
|
||||
return nil, fmt.Errorf("failed to update name: %w", err)
|
||||
}
|
||||
|
||||
// Fetch updated user to ensure we have the latest data
|
||||
updatedUser, err := s.userRepository.GetUserByID(existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get updated user: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
fmt.Sprintf("Invited user completed registration: %s", existingUser.Email),
|
||||
&existingUser.ID,
|
||||
fmt.Sprintf("Invited user completed registration: %s", updatedUser.Email),
|
||||
&updatedUser.ID,
|
||||
nil,
|
||||
)
|
||||
|
||||
return nil
|
||||
return updatedUser, nil
|
||||
}
|
||||
|
||||
// Get settings to check registration policy for new users
|
||||
settings, err := s.settingsService.GetSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get settings: %w", err)
|
||||
return nil, fmt.Errorf("failed to get settings: %w", err)
|
||||
}
|
||||
|
||||
// Check if external registrations are allowed
|
||||
if !settings.IsAllowExternalRegistrations {
|
||||
return errors.New("external registration is disabled")
|
||||
return nil, errors.New("external registration is disabled")
|
||||
}
|
||||
|
||||
user := &users_models.User{
|
||||
@@ -114,7 +120,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
}
|
||||
|
||||
if err := s.userRepository.CreateUser(user); err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
@@ -123,7 +129,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
nil,
|
||||
)
|
||||
|
||||
return nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SignIn(
|
||||
@@ -258,6 +264,7 @@ func (s *UserService) GenerateAccessToken(
|
||||
|
||||
return &users_dto.SignInResponseDTO{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Token: tokenString,
|
||||
}, nil
|
||||
}
|
||||
@@ -463,6 +470,178 @@ func (s *UserService) HandleGitHubOAuth(
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) HandleGoogleOAuth(
|
||||
code, redirectUri string,
|
||||
) (*users_dto.OAuthCallbackResponseDTO, error) {
|
||||
return s.handleGoogleOAuthWithEndpoint(
|
||||
code,
|
||||
redirectUri,
|
||||
google.Endpoint,
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) SendResetPasswordCode(email string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Silently succeed for non-existent users to prevent enumeration attacks
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only active users can reset passwords
|
||||
if user.Status != users_enums.UserStatusActive {
|
||||
return errors.New("only active users can reset their password")
|
||||
}
|
||||
|
||||
// Check rate limiting - max 3 codes per hour
|
||||
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
|
||||
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check rate limit: %w", err)
|
||||
}
|
||||
|
||||
if recentCount >= 3 {
|
||||
return errors.New("too many password reset attempts, please try again later")
|
||||
}
|
||||
|
||||
// Generate 6-digit random code using crypto/rand for better randomness
|
||||
codeNum := make([]byte, 4)
|
||||
_, err = io.ReadFull(rand.Reader, codeNum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate random code: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to uint32 and modulo to get 6 digits
|
||||
randomInt := uint32(
|
||||
codeNum[0],
|
||||
)<<24 | uint32(
|
||||
codeNum[1],
|
||||
)<<16 | uint32(
|
||||
codeNum[2],
|
||||
)<<8 | uint32(
|
||||
codeNum[3],
|
||||
)
|
||||
code := fmt.Sprintf("%06d", randomInt%1000000)
|
||||
|
||||
// Hash the code
|
||||
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash code: %w", err)
|
||||
}
|
||||
|
||||
// Store in database with 1 hour expiration
|
||||
resetCode := &users_models.PasswordResetCode{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
HashedCode: string(hashedCode),
|
||||
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
|
||||
IsUsed: false,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
|
||||
return fmt.Errorf("failed to create reset code: %w", err)
|
||||
}
|
||||
|
||||
// Send email with code
|
||||
if s.emailSender != nil {
|
||||
subject := "Password Reset Code"
|
||||
body := fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
</head>
|
||||
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
|
||||
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
|
||||
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
You have requested to reset your password. Please use the following code to complete the password reset process:
|
||||
</p>
|
||||
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
|
||||
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
|
||||
</div>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
This code will expire in <strong>1 hour</strong>.
|
||||
</p>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
|
||||
</p>
|
||||
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
|
||||
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
|
||||
This is an automated message. Please do not reply to this email.
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, code)
|
||||
|
||||
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
fmt.Sprintf("Password reset code sent to: %s", user.Email),
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) ResetPassword(email, code, newPassword string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return errors.New("user with this email does not exist")
|
||||
}
|
||||
|
||||
// Get valid reset code for user
|
||||
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
|
||||
if err != nil {
|
||||
return errors.New("invalid or expired reset code")
|
||||
}
|
||||
|
||||
// Verify code matches
|
||||
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
|
||||
if err != nil {
|
||||
return errors.New("invalid reset code")
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
|
||||
return fmt.Errorf("failed to mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// Update user password
|
||||
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
"Password reset via email code",
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) handleGitHubOAuthWithEndpoint(
|
||||
code, redirectUri string,
|
||||
endpoint oauth2.Endpoint,
|
||||
@@ -529,17 +708,6 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
|
||||
return s.getOrCreateUserFromOAuth(oauthID, email, name, "github")
|
||||
}
|
||||
|
||||
func (s *UserService) HandleGoogleOAuth(
|
||||
code, redirectUri string,
|
||||
) (*users_dto.OAuthCallbackResponseDTO, error) {
|
||||
return s.handleGoogleOAuthWithEndpoint(
|
||||
code,
|
||||
redirectUri,
|
||||
google.Endpoint,
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) handleGoogleOAuthWithEndpoint(
|
||||
code, redirectUri string,
|
||||
endpoint oauth2.Endpoint,
|
||||
@@ -805,164 +973,3 @@ func (s *UserService) fetchGitHubPrimaryEmail(
|
||||
|
||||
return "", errors.New("github account has no accessible email")
|
||||
}
|
||||
|
||||
func (s *UserService) SendResetPasswordCode(email string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Silently succeed for non-existent users to prevent enumeration attacks
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only active users can reset passwords
|
||||
if user.Status != users_enums.UserStatusActive {
|
||||
return errors.New("only active users can reset their password")
|
||||
}
|
||||
|
||||
// Check rate limiting - max 3 codes per hour
|
||||
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
|
||||
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check rate limit: %w", err)
|
||||
}
|
||||
|
||||
if recentCount >= 3 {
|
||||
return errors.New("too many password reset attempts, please try again later")
|
||||
}
|
||||
|
||||
// Generate 6-digit random code using crypto/rand for better randomness
|
||||
codeNum := make([]byte, 4)
|
||||
_, err = io.ReadFull(rand.Reader, codeNum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate random code: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to uint32 and modulo to get 6 digits
|
||||
randomInt := uint32(
|
||||
codeNum[0],
|
||||
)<<24 | uint32(
|
||||
codeNum[1],
|
||||
)<<16 | uint32(
|
||||
codeNum[2],
|
||||
)<<8 | uint32(
|
||||
codeNum[3],
|
||||
)
|
||||
code := fmt.Sprintf("%06d", randomInt%1000000)
|
||||
|
||||
// Hash the code
|
||||
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash code: %w", err)
|
||||
}
|
||||
|
||||
// Store in database with 1 hour expiration
|
||||
resetCode := &users_models.PasswordResetCode{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
HashedCode: string(hashedCode),
|
||||
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
|
||||
IsUsed: false,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
|
||||
return fmt.Errorf("failed to create reset code: %w", err)
|
||||
}
|
||||
|
||||
// Send email with code
|
||||
if s.emailSender != nil {
|
||||
subject := "Password Reset Code"
|
||||
body := fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
</head>
|
||||
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
|
||||
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
|
||||
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
You have requested to reset your password. Please use the following code to complete the password reset process:
|
||||
</p>
|
||||
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
|
||||
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
|
||||
</div>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
This code will expire in <strong>1 hour</strong>.
|
||||
</p>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
|
||||
</p>
|
||||
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
|
||||
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
|
||||
This is an automated message. Please do not reply to this email.
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, code)
|
||||
|
||||
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
fmt.Sprintf("Password reset code sent to: %s", user.Email),
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) ResetPassword(email, code, newPassword string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return errors.New("user with this email does not exist")
|
||||
}
|
||||
|
||||
// Get valid reset code for user
|
||||
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
|
||||
if err != nil {
|
||||
return errors.New("invalid or expired reset code")
|
||||
}
|
||||
|
||||
// Verify code matches
|
||||
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
|
||||
if err != nil {
|
||||
return errors.New("invalid reset code")
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
|
||||
return fmt.Errorf("failed to mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// Update user password
|
||||
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
"Password reset via email code",
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package cloudflare_turnstile
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
type CloudflareTurnstileService struct {
|
||||
secretKey string
|
||||
siteKey string
|
||||
}
|
||||
|
||||
type cloudflareTurnstileResponse struct {
|
||||
Success bool `json:"success"`
|
||||
ChallengeTS time.Time `json:"challenge_ts"`
|
||||
Hostname string `json:"hostname"`
|
||||
ErrorCodes []string `json:"error-codes"`
|
||||
}
|
||||
|
||||
const cloudflareTurnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"
|
||||
|
||||
func (s *CloudflareTurnstileService) IsEnabled() bool {
|
||||
return s.secretKey != ""
|
||||
}
|
||||
|
||||
func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool, error) {
|
||||
if !s.IsEnabled() {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
return false, errors.New("cloudflare Turnstile token is required")
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("secret", s.secretKey)
|
||||
formData.Set("response", token)
|
||||
formData.Set("remoteip", remoteIP)
|
||||
|
||||
resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read Cloudflare Turnstile response: %w", err)
|
||||
}
|
||||
|
||||
var turnstileResp cloudflareTurnstileResponse
|
||||
if err := json.Unmarshal(body, &turnstileResp); err != nil {
|
||||
return false, fmt.Errorf("failed to parse Cloudflare Turnstile response: %w", err)
|
||||
}
|
||||
|
||||
if !turnstileResp.Success {
|
||||
return false, fmt.Errorf(
|
||||
"cloudflare Turnstile verification failed: %v",
|
||||
turnstileResp.ErrorCodes,
|
||||
)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
14
backend/internal/util/cloudflare_turnstile/di.go
Normal file
14
backend/internal/util/cloudflare_turnstile/di.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cloudflare_turnstile
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
)
|
||||
|
||||
var cloudflareTurnstileService = &CloudflareTurnstileService{
|
||||
config.GetEnv().CloudflareTurnstileSecretKey,
|
||||
config.GetEnv().CloudflareTurnstileSiteKey,
|
||||
}
|
||||
|
||||
func GetCloudflareTurnstileService() *CloudflareTurnstileService {
|
||||
return cloudflareTurnstileService
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ALTER COLUMN port DROP NOT NULL;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ADD COLUMN is_srv BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases DROP COLUMN is_srv;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mongodb_databases ALTER COLUMN port SET NOT NULL;
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mariadb_databases
|
||||
ADD COLUMN IF NOT EXISTS is_exclude_events BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE mariadb_databases
|
||||
DROP COLUMN IF EXISTS is_exclude_events;
|
||||
-- +goose StatementEnd
|
||||
@@ -2,4 +2,5 @@ MODE=development
|
||||
VITE_GITHUB_CLIENT_ID=
|
||||
VITE_GOOGLE_CLIENT_ID=
|
||||
VITE_IS_EMAIL_CONFIGURED=false
|
||||
VITE_IS_CLOUD=false
|
||||
VITE_IS_CLOUD=false
|
||||
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
@@ -3,6 +3,7 @@ interface RuntimeConfig {
|
||||
GITHUB_CLIENT_ID?: string;
|
||||
GOOGLE_CLIENT_ID?: string;
|
||||
IS_EMAIL_CONFIGURED?: string;
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
|
||||
}
|
||||
|
||||
declare global {
|
||||
@@ -39,6 +40,11 @@ export const IS_EMAIL_CONFIGURED =
|
||||
window.__RUNTIME_CONFIG__?.IS_EMAIL_CONFIGURED === 'true' ||
|
||||
import.meta.env.VITE_IS_EMAIL_CONFIGURED === 'true';
|
||||
|
||||
export const CLOUDFLARE_TURNSTILE_SITE_KEY =
|
||||
window.__RUNTIME_CONFIG__?.CLOUDFLARE_TURNSTILE_SITE_KEY ||
|
||||
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
|
||||
'';
|
||||
|
||||
export function getOAuthRedirectUri(): string {
|
||||
return `${window.location.origin}/auth/callback`;
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('mydb');
|
||||
expect(result.authDatabase).toBe('admin');
|
||||
expect(result.useTls).toBe(false);
|
||||
expect(result.isSrv).toBe(false);
|
||||
});
|
||||
|
||||
it('should parse connection string without database', () => {
|
||||
@@ -46,6 +47,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('');
|
||||
expect(result.authDatabase).toBe('admin');
|
||||
expect(result.useTls).toBe(false);
|
||||
expect(result.isSrv).toBe(false);
|
||||
});
|
||||
|
||||
it('should default port to 27017 when not specified', () => {
|
||||
@@ -107,6 +109,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.password).toBe('atlaspass');
|
||||
expect(result.database).toBe('mydb');
|
||||
expect(result.useTls).toBe(true); // SRV connections use TLS by default
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
|
||||
it('should parse mongodb+srv:// without database', () => {
|
||||
@@ -119,6 +122,7 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.host).toBe('cluster0.abc123.mongodb.net');
|
||||
expect(result.database).toBe('');
|
||||
expect(result.useTls).toBe(true);
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -314,13 +318,15 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.format).toBe('key-value');
|
||||
});
|
||||
|
||||
it('should return error for key-value format missing password', () => {
|
||||
const result = expectError(
|
||||
it('should allow missing password in key-value format (returns empty password)', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('host=localhost database=mydb user=admin'),
|
||||
);
|
||||
|
||||
expect(result.error).toContain('Password');
|
||||
expect(result.format).toBe('key-value');
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -351,12 +357,15 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.error).toContain('Username');
|
||||
});
|
||||
|
||||
it('should return error for missing password in URI', () => {
|
||||
const result = expectError(
|
||||
it('should allow missing password in URI (returns empty password)', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.error).toContain('Password');
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should return error for mysql:// format (wrong database type)', () => {
|
||||
@@ -446,4 +455,67 @@ describe('MongodbConnectionStringParser', () => {
|
||||
expect(result.database).toBe('');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Password Placeholder Handling', () => {
|
||||
it('should treat <db_password> placeholder as empty password in URI format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user:<db_password>@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should treat <password> placeholder as empty password in URI format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse('mongodb://user:<password>@host:27017/db'),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('host');
|
||||
expect(result.database).toBe('db');
|
||||
});
|
||||
|
||||
it('should treat <db_password> placeholder as empty password in SRV format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'mongodb+srv://user:<db_password>@cluster0.mongodb.net/db',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.username).toBe('user');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.host).toBe('cluster0.mongodb.net');
|
||||
expect(result.isSrv).toBe(true);
|
||||
});
|
||||
|
||||
it('should treat <db_password> placeholder as empty password in key-value format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'host=localhost database=mydb user=admin password=<db_password>',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
|
||||
it('should treat <password> placeholder as empty password in key-value format', () => {
|
||||
const result = expectSuccess(
|
||||
MongodbConnectionStringParser.parse(
|
||||
'host=localhost database=mydb user=admin password=<password>',
|
||||
),
|
||||
);
|
||||
|
||||
expect(result.host).toBe('localhost');
|
||||
expect(result.username).toBe('admin');
|
||||
expect(result.password).toBe('');
|
||||
expect(result.database).toBe('mydb');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,6 +6,7 @@ export type ParseResult = {
|
||||
database: string;
|
||||
authDatabase: string;
|
||||
useTls: boolean;
|
||||
isSrv: boolean;
|
||||
};
|
||||
|
||||
export type ParseError = {
|
||||
@@ -63,7 +64,8 @@ export class MongodbConnectionStringParser {
|
||||
const host = url.hostname;
|
||||
const port = url.port ? parseInt(url.port, 10) : isSrv ? 27017 : 27017;
|
||||
const username = decodeURIComponent(url.username);
|
||||
const password = decodeURIComponent(url.password);
|
||||
const rawPassword = decodeURIComponent(url.password);
|
||||
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword;
|
||||
const database = decodeURIComponent(url.pathname.slice(1));
|
||||
const authDatabase = this.getAuthSource(url.search) || 'admin';
|
||||
const useTls = isSrv ? true : this.checkTlsMode(url.search);
|
||||
@@ -76,10 +78,6 @@ export class MongodbConnectionStringParser {
|
||||
return { error: 'Username is missing from connection string' };
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
return { error: 'Password is missing from connection string' };
|
||||
}
|
||||
|
||||
return {
|
||||
host,
|
||||
port,
|
||||
@@ -88,6 +86,7 @@ export class MongodbConnectionStringParser {
|
||||
database: database || '',
|
||||
authDatabase,
|
||||
useTls,
|
||||
isSrv,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
@@ -114,7 +113,8 @@ export class MongodbConnectionStringParser {
|
||||
const port = params['port'];
|
||||
const database = params['database'] || params['dbname'] || params['db'];
|
||||
const username = params['user'] || params['username'];
|
||||
const password = params['password'];
|
||||
const rawPassword = params['password'];
|
||||
const password = this.isPasswordPlaceholder(rawPassword) ? '' : rawPassword || '';
|
||||
const authDatabase = params['authSource'] || params['authDatabase'] || 'admin';
|
||||
const tls = params['tls'] || params['ssl'];
|
||||
|
||||
@@ -132,13 +132,6 @@ export class MongodbConnectionStringParser {
|
||||
};
|
||||
}
|
||||
|
||||
if (!password) {
|
||||
return {
|
||||
error: 'Password is missing from connection string. Use password=yourpassword',
|
||||
format: 'key-value',
|
||||
};
|
||||
}
|
||||
|
||||
const useTls = this.isTlsEnabled(tls);
|
||||
|
||||
return {
|
||||
@@ -149,6 +142,7 @@ export class MongodbConnectionStringParser {
|
||||
database: database || '',
|
||||
authDatabase,
|
||||
useTls,
|
||||
isSrv: false,
|
||||
};
|
||||
} catch (e) {
|
||||
return {
|
||||
@@ -191,4 +185,11 @@ export class MongodbConnectionStringParser {
|
||||
const enabledValues = ['true', 'yes', '1'];
|
||||
return enabledValues.includes(lowercased);
|
||||
}
|
||||
|
||||
private static isPasswordPlaceholder(password: string | null | undefined): boolean {
|
||||
if (!password) return false;
|
||||
|
||||
const trimmed = password.trim();
|
||||
return trimmed === '<db_password>' || trimmed === '<password>';
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,5 +10,6 @@ export interface MongodbDatabase {
|
||||
database: string;
|
||||
authDatabase: string;
|
||||
isHttps: boolean;
|
||||
isSrv: boolean;
|
||||
cpuCount: number;
|
||||
}
|
||||
|
||||
@@ -31,10 +31,18 @@ const notifyAuthListeners = () => {
|
||||
};
|
||||
|
||||
export const userApi = {
|
||||
async signUp(signUpRequest: SignUpRequest) {
|
||||
async signUp(signUpRequest: SignUpRequest): Promise<SignInResponse> {
|
||||
const requestOptions: RequestOptions = new RequestOptions();
|
||||
requestOptions.setBody(JSON.stringify(signUpRequest));
|
||||
return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/users/signup`, requestOptions);
|
||||
|
||||
return apiHelper
|
||||
.fetchPostJson(`${getApplicationServer()}/api/v1/users/signup`, requestOptions)
|
||||
.then((response: unknown): SignInResponse => {
|
||||
const typedResponse = response as SignInResponse;
|
||||
saveAuthorizedData(typedResponse.token, typedResponse.userId);
|
||||
notifyAuthListeners();
|
||||
return typedResponse;
|
||||
});
|
||||
},
|
||||
|
||||
async signIn(signInRequest: SignInRequest): Promise<SignInResponse> {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
export interface SendResetPasswordCodeRequest {
|
||||
email: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
export interface SignInRequest {
|
||||
email: string;
|
||||
password: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -2,4 +2,5 @@ export interface SignUpRequest {
|
||||
email: string;
|
||||
password: string;
|
||||
name: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
const [isTestingConnection, setIsTestingConnection] = useState(false);
|
||||
const [isConnectionFailed, setIsConnectionFailed] = useState(false);
|
||||
|
||||
const hasAdvancedValues = !!database.mongodb?.authDatabase;
|
||||
const hasAdvancedValues = !!database.mongodb?.authDatabase || !!database.mongodb?.isSrv;
|
||||
const [isShowAdvanced, setShowAdvanced] = useState(hasAdvancedValues);
|
||||
|
||||
const parseFromClipboard = async () => {
|
||||
@@ -75,17 +75,29 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
host: result.host,
|
||||
port: result.port,
|
||||
username: result.username,
|
||||
password: result.password,
|
||||
password: result.password || '',
|
||||
database: result.database,
|
||||
authDatabase: result.authDatabase,
|
||||
isHttps: result.useTls,
|
||||
isSrv: result.isSrv,
|
||||
cpuCount: 1,
|
||||
},
|
||||
};
|
||||
|
||||
if (result.isSrv) {
|
||||
setShowAdvanced(true);
|
||||
}
|
||||
|
||||
setEditingDatabase(updatedDatabase);
|
||||
setIsConnectionTested(false);
|
||||
message.success('Connection string parsed successfully');
|
||||
|
||||
if (!result.password) {
|
||||
message.warning(
|
||||
'Connection string parsed successfully. Please enter the password manually.',
|
||||
);
|
||||
} else {
|
||||
message.success('Connection string parsed successfully');
|
||||
}
|
||||
} catch {
|
||||
message.error('Failed to read clipboard. Please check browser permissions.');
|
||||
}
|
||||
@@ -156,9 +168,11 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
|
||||
if (!editingDatabase) return null;
|
||||
|
||||
const isSrvConnection = editingDatabase.mongodb?.isSrv || false;
|
||||
|
||||
let isAllFieldsFilled = true;
|
||||
if (!editingDatabase.mongodb?.host) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.port) isAllFieldsFilled = false;
|
||||
if (!isSrvConnection && !editingDatabase.mongodb?.port) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.username) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.id && !editingDatabase.mongodb?.password) isAllFieldsFilled = false;
|
||||
if (!editingDatabase.mongodb?.database) isAllFieldsFilled = false;
|
||||
@@ -220,25 +234,27 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Port</div>
|
||||
<InputNumber
|
||||
type="number"
|
||||
value={editingDatabase.mongodb?.port}
|
||||
onChange={(e) => {
|
||||
if (!editingDatabase.mongodb || e === null) return;
|
||||
{!isSrvConnection && (
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Port</div>
|
||||
<InputNumber
|
||||
type="number"
|
||||
value={editingDatabase.mongodb?.port}
|
||||
onChange={(e) => {
|
||||
if (!editingDatabase.mongodb || e === null) return;
|
||||
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, port: e },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
className="max-w-[200px] grow"
|
||||
placeholder="27017"
|
||||
/>
|
||||
</div>
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, port: e },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
className="max-w-[200px] grow"
|
||||
placeholder="27017"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Username</div>
|
||||
@@ -366,6 +382,31 @@ export const EditMongoDbSpecificDataComponent = ({
|
||||
|
||||
{isShowAdvanced && (
|
||||
<>
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Use SRV connection</div>
|
||||
<div className="flex items-center">
|
||||
<Switch
|
||||
checked={editingDatabase.mongodb?.isSrv || false}
|
||||
onChange={(checked) => {
|
||||
if (!editingDatabase.mongodb) return;
|
||||
|
||||
setEditingDatabase({
|
||||
...editingDatabase,
|
||||
mongodb: { ...editingDatabase.mongodb, isSrv: checked },
|
||||
});
|
||||
setIsConnectionTested(false);
|
||||
}}
|
||||
size="small"
|
||||
/>
|
||||
<Tooltip
|
||||
className="cursor-pointer"
|
||||
title="Enable for MongoDB Atlas SRV connections (mongodb+srv://). Port is not required for SRV connections."
|
||||
>
|
||||
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mb-1 flex w-full items-center">
|
||||
<div className="min-w-[150px]">Auth database</div>
|
||||
<Input
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
|
||||
interface RequestResetPasswordComponentProps {
|
||||
onSwitchToSignIn?: () => void;
|
||||
@@ -20,6 +23,8 @@ export function RequestResetPasswordComponent({
|
||||
const [error, setError] = useState('');
|
||||
const [successMessage, setSuccessMessage] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateEmail = (): boolean => {
|
||||
if (!email) {
|
||||
setEmailError(true);
|
||||
@@ -42,7 +47,10 @@ export function RequestResetPasswordComponent({
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const response = await userApi.sendResetPasswordCode({ email });
|
||||
const response = await userApi.sendResetPasswordCode({
|
||||
email,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
setSuccessMessage(response.message);
|
||||
|
||||
// After successful code send, switch to reset password form
|
||||
@@ -53,6 +61,7 @@ export function RequestResetPasswordComponent({
|
||||
}, 2000);
|
||||
} catch (e) {
|
||||
setError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
@@ -84,6 +93,8 @@ export function RequestResetPasswordComponent({
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID, IS_EMAIL_CONFIGURED } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
@@ -29,6 +32,8 @@ export function SignInComponent({
|
||||
|
||||
const [signInError, setSignInError] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateFieldsForSignIn = (): boolean => {
|
||||
if (!email) {
|
||||
setEmailError(true);
|
||||
@@ -59,9 +64,11 @@ export function SignInComponent({
|
||||
await userApi.signIn({
|
||||
email,
|
||||
password,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
} catch (e) {
|
||||
setSignInError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
@@ -119,6 +126,8 @@ export function SignInComponent({
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { App, Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
@@ -31,6 +34,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
|
||||
const [signUpError, setSignUpError] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateFieldsForSignUp = (): boolean => {
|
||||
if (!name || name.trim() === '') {
|
||||
setNameError(true);
|
||||
@@ -85,10 +90,11 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
email,
|
||||
password,
|
||||
name,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
await userApi.signIn({ email, password });
|
||||
} catch (e) {
|
||||
setSignUpError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,6 +179,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import dayjs from 'dayjs';
|
||||
import relativeTime from 'dayjs/plugin/relativeTime';
|
||||
import utc from 'dayjs/plugin/utc';
|
||||
import { StrictMode } from 'react';
|
||||
import { createRoot } from 'react-dom/client';
|
||||
|
||||
import './index.css';
|
||||
@@ -11,8 +10,4 @@ import App from './App.tsx';
|
||||
dayjs.extend(utc);
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
createRoot(document.getElementById('root')!).render(
|
||||
<StrictMode>
|
||||
<App />
|
||||
</StrictMode>,
|
||||
);
|
||||
createRoot(document.getElementById('root')!).render(<App />);
|
||||
|
||||
116
frontend/src/shared/hooks/useCloudflareTurnstile.ts
Normal file
116
frontend/src/shared/hooks/useCloudflareTurnstile.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
turnstile?: {
|
||||
render: (
|
||||
container: string | HTMLElement,
|
||||
options: {
|
||||
sitekey: string;
|
||||
callback: (token: string) => void;
|
||||
'error-callback'?: () => void;
|
||||
'expired-callback'?: () => void;
|
||||
theme?: 'light' | 'dark' | 'auto';
|
||||
size?: 'normal' | 'compact' | 'flexible';
|
||||
appearance?: 'always' | 'execute' | 'interaction-only';
|
||||
},
|
||||
) => string;
|
||||
reset: (widgetId: string) => void;
|
||||
remove: (widgetId: string) => void;
|
||||
getResponse: (widgetId: string) => string | undefined;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
interface UseCloudflareTurnstileReturn {
|
||||
containerRef: React.RefObject<HTMLDivElement | null>;
|
||||
token: string | undefined;
|
||||
resetCloudflareTurnstile: () => void;
|
||||
}
|
||||
|
||||
const loadCloudflareTurnstileScript = (): Promise<void> => {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
if (document.querySelector('script[src*="turnstile"]')) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
const script = document.createElement('script');
|
||||
script.src = 'https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit';
|
||||
script.async = true;
|
||||
script.defer = true;
|
||||
script.onload = () => resolve();
|
||||
script.onerror = () => reject(new Error('Failed to load Cloudflare Turnstile'));
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
};
|
||||
|
||||
export function useCloudflareTurnstile(): UseCloudflareTurnstileReturn {
|
||||
const [token, setToken] = useState<string | undefined>(undefined);
|
||||
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||
const widgetIdRef = useRef<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY || !containerRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadCloudflareTurnstileScript()
|
||||
.then(() => {
|
||||
if (!window.turnstile || !containerRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const widgetId = window.turnstile.render(containerRef.current, {
|
||||
sitekey: CLOUDFLARE_TURNSTILE_SITE_KEY,
|
||||
callback: (receivedToken: string) => {
|
||||
setToken(receivedToken);
|
||||
},
|
||||
'error-callback': () => {
|
||||
setToken(undefined);
|
||||
},
|
||||
'expired-callback': () => {
|
||||
setToken(undefined);
|
||||
},
|
||||
theme: 'auto',
|
||||
size: 'normal',
|
||||
appearance: 'execute',
|
||||
});
|
||||
|
||||
widgetIdRef.current = widgetId;
|
||||
} catch (error) {
|
||||
console.error('Failed to render Cloudflare Turnstile widget:', error);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to load Cloudflare Turnstile:', error);
|
||||
});
|
||||
|
||||
return () => {
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.remove(widgetIdRef.current);
|
||||
widgetIdRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const resetCloudflareTurnstile = () => {
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.reset(widgetIdRef.current);
|
||||
setToken(undefined);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
token,
|
||||
resetCloudflareTurnstile,
|
||||
};
|
||||
}
|
||||
17
frontend/src/shared/ui/CloudflareTurnstileWidget.tsx
Normal file
17
frontend/src/shared/ui/CloudflareTurnstileWidget.tsx
Normal file
@@ -0,0 +1,17 @@
|
||||
import { type JSX } from 'react';
|
||||
|
||||
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
|
||||
|
||||
interface CloudflareTurnstileWidgetProps {
|
||||
containerRef: React.RefObject<HTMLDivElement | null>;
|
||||
}
|
||||
|
||||
export function CloudflareTurnstileWidget({
|
||||
containerRef,
|
||||
}: CloudflareTurnstileWidgetProps): JSX.Element | null {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <div ref={containerRef} className="mb-3" />;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
export { CloudflareTurnstileWidget } from './CloudflareTurnstileWidget';
|
||||
export { ConfirmationComponent } from './ConfirmationComponent';
|
||||
export { StarButtonComponent } from './StarButtonComponent';
|
||||
export { ThemeToggleComponent } from './ThemeToggleComponent';
|
||||
|
||||
Reference in New Issue
Block a user