mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f1d80245f | ||
|
|
16a29cf458 | ||
|
|
43e04500ac | ||
|
|
cee3022f85 | ||
|
|
f46d92c480 | ||
|
|
10677238d7 | ||
|
|
2553203fcf | ||
|
|
7b05bd8000 | ||
|
|
8d45728f73 | ||
|
|
c70ad82c95 |
@@ -268,7 +268,8 @@ window.__RUNTIME_CONFIG__ = {
|
||||
IS_CLOUD: '\${IS_CLOUD:-false}',
|
||||
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
|
||||
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED'
|
||||
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
|
||||
};
|
||||
JSEOF
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
[](https://www.mongodb.com/)
|
||||
<br />
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://hub.docker.com/r/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
[](https://github.com/databasus/databasus)
|
||||
@@ -31,8 +31,6 @@
|
||||
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
|
||||
|
||||
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
@@ -11,6 +11,9 @@ VICTORIA_LOGS_PASSWORD=devpassword
|
||||
# tests
|
||||
TEST_LOCALHOST=localhost
|
||||
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
|
||||
# cloudflare turnstile
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
CLOUDFLARE_TURNSTILE_SECRET_KEY=
|
||||
# db
|
||||
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
|
||||
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable
|
||||
|
||||
@@ -104,6 +104,10 @@ type EnvVariables struct {
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// Cloudflare Turnstile
|
||||
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
|
||||
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
|
||||
@@ -196,7 +196,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
|
||||
backupMetadata, err := n.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -263,7 +263,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID.String()); deleteErr != nil {
|
||||
n.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
|
||||
@@ -79,7 +79,7 @@ func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(c.fieldEncryptor, backup.ID)
|
||||
err = storage.DeleteFile(c.fieldEncryptor, backup.ID.String())
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
|
||||
@@ -25,24 +25,24 @@ var backupRepository = &backups_core.BackupRepository{}
|
||||
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
|
||||
|
||||
var backupCleaner = &BackupCleaner{
|
||||
backupRepository: backupRepository,
|
||||
storageService: storages.GetStorageService(),
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
logger: logger.GetLogger(),
|
||||
backupRemoveListeners: []backups_core.BackupRemoveListener{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
backupRepository,
|
||||
storages.GetStorageService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger.GetLogger(),
|
||||
[]backups_core.BackupRemoveListener{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupNodesRegistry = &BackupNodesRegistry{
|
||||
client: cache_utils.GetValkeyClient(),
|
||||
logger: logger.GetLogger(),
|
||||
timeout: cache_utils.DefaultCacheTimeout,
|
||||
pubsubBackups: cache_utils.NewPubSubManager(),
|
||||
pubsubCompletions: cache_utils.NewPubSubManager(),
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
cache_utils.GetValkeyClient(),
|
||||
logger.GetLogger(),
|
||||
cache_utils.DefaultCacheTimeout,
|
||||
cache_utils.NewPubSubManager(),
|
||||
cache_utils.NewPubSubManager(),
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func getNodeID() uuid.UUID {
|
||||
@@ -50,34 +50,35 @@ func getNodeID() uuid.UUID {
|
||||
}
|
||||
|
||||
var backuperNode = &BackuperNode{
|
||||
databaseService: databases.GetDatabaseService(),
|
||||
fieldEncryptor: encryption.GetFieldEncryptor(),
|
||||
workspaceService: workspaces_services.GetWorkspaceService(),
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
storageService: storages.GetStorageService(),
|
||||
notificationSender: notifiers.GetNotifierService(),
|
||||
backupCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
logger: logger.GetLogger(),
|
||||
createBackupUseCase: usecases.GetCreateBackupUsecase(),
|
||||
nodeID: getNodeID(),
|
||||
lastHeartbeat: time.Time{},
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
databases.GetDatabaseService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
storages.GetStorageService(),
|
||||
notifiers.GetNotifierService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
logger.GetLogger(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
getNodeID(),
|
||||
time.Time{},
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
var backupsScheduler = &BackupsScheduler{
|
||||
backupRepository: backupRepository,
|
||||
backupConfigService: backups_config.GetBackupConfigService(),
|
||||
taskCancelManager: taskCancelManager,
|
||||
backupNodesRegistry: backupNodesRegistry,
|
||||
lastBackupTime: time.Now().UTC(),
|
||||
logger: logger.GetLogger(),
|
||||
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode: backuperNode,
|
||||
runOnce: sync.Once{},
|
||||
hasRun: atomic.Bool{},
|
||||
backupRepository,
|
||||
backups_config.GetBackupConfigService(),
|
||||
taskCancelManager,
|
||||
backupNodesRegistry,
|
||||
databases.GetDatabaseService(),
|
||||
time.Now().UTC(),
|
||||
logger.GetLogger(),
|
||||
make(map[uuid.UUID]BackupToNodeRelation),
|
||||
backuperNode,
|
||||
sync.Once{},
|
||||
atomic.Bool{},
|
||||
}
|
||||
|
||||
func GetBackupsScheduler() *BackupsScheduler {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
@@ -32,7 +33,7 @@ type CreateFailedBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -46,7 +47,7 @@ type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -65,7 +66,7 @@ type CreateLargeBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateLargeBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -84,7 +85,7 @@ type CreateProgressiveBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateProgressiveBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -124,7 +125,7 @@ type CreateMediumBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateMediumBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -152,7 +153,7 @@ func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
|
||||
|
||||
func (m *MockTrackingBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -162,7 +163,7 @@ func (m *MockTrackingBackupUsecase) Execute(
|
||||
|
||||
// Send backup ID to channel (non-blocking)
|
||||
select {
|
||||
case m.calledBackupIDs <- backupID:
|
||||
case m.calledBackupIDs <- backup.ID:
|
||||
default:
|
||||
}
|
||||
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"databasus-backend/internal/config"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -27,6 +29,7 @@ type BackupsScheduler struct {
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
taskCancelManager *task_cancellation.TaskCancelManager
|
||||
backupNodesRegistry *BackupNodesRegistry
|
||||
databaseService *databases.DatabaseService
|
||||
|
||||
lastBackupTime time.Time
|
||||
logger *slog.Logger
|
||||
@@ -113,28 +116,28 @@ func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
|
||||
return len(nodes) > 0
|
||||
}
|
||||
|
||||
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
|
||||
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get backup config by database ID", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if backupConfig.StorageID == nil {
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
|
||||
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
|
||||
return
|
||||
}
|
||||
|
||||
// Check for existing in-progress backups
|
||||
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
|
||||
databaseID,
|
||||
database.ID,
|
||||
backups_core.BackupStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
s.logger.Error(
|
||||
"Failed to check for in-progress backups",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
database.ID,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
@@ -145,7 +148,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
s.logger.Warn(
|
||||
"Backup already in progress for database, skipping new backup",
|
||||
"databaseId",
|
||||
databaseID,
|
||||
database.ID,
|
||||
"existingBackupId",
|
||||
inProgressBackups[0].ID,
|
||||
)
|
||||
@@ -164,13 +167,22 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Println("make backup")
|
||||
backupID := uuid.New()
|
||||
timestamp := time.Now().UTC()
|
||||
|
||||
backup := &backups_core.Backup{
|
||||
ID: backupID,
|
||||
FileName: fmt.Sprintf(
|
||||
"%s-%s-%s",
|
||||
files_utils.SanitizeFilename(database.Name),
|
||||
timestamp.Format("20060102-150405"),
|
||||
backupID.String(),
|
||||
),
|
||||
DatabaseID: backupConfig.DatabaseID,
|
||||
StorageID: *backupConfig.StorageID,
|
||||
Status: backups_core.BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
CreatedAt: timestamp,
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
@@ -224,8 +236,8 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = relation
|
||||
} else {
|
||||
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
|
||||
NodeID: *leastBusyNodeID,
|
||||
BackupsIDs: []uuid.UUID{backup.ID},
|
||||
*leastBusyNodeID,
|
||||
[]uuid.UUID{backup.ID},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,7 +341,13 @@ func (s *BackupsScheduler) runPendingBackups() error {
|
||||
backupConfig.BackupInterval.Interval,
|
||||
)
|
||||
|
||||
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
|
||||
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to get database by ID", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.StartBackup(database, remainedBackupTryCount == 1)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
@@ -492,7 +492,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Scheduler assigns backup to mock node
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
GetBackupsScheduler().StartBackup(database, false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -595,7 +595,7 @@ func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Start a backup and assign it to the node
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
GetBackupsScheduler().StartBackup(database, false)
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
backups, err := backupRepository.FindByDatabaseID(database.ID)
|
||||
@@ -892,7 +892,7 @@ func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
// Wait for backup to complete
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
@@ -995,7 +995,7 @@ func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
|
||||
t.Logf("Initial active tasks: %d", initialActiveTasks)
|
||||
|
||||
// Start backup
|
||||
scheduler.StartBackup(database.ID, false)
|
||||
scheduler.StartBackup(database, false)
|
||||
|
||||
// Wait for backup to fail
|
||||
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
|
||||
@@ -1088,7 +1088,7 @@ func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Try to start a new backup - should be skipped
|
||||
GetBackupsScheduler().StartBackup(database.ID, false)
|
||||
GetBackupsScheduler().StartBackup(database, false)
|
||||
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
@@ -1268,10 +1268,10 @@ func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCa
|
||||
|
||||
// Start 2 backups simultaneously
|
||||
t.Log("Starting backup for database1")
|
||||
scheduler.StartBackup(database1.ID, false)
|
||||
scheduler.StartBackup(database1, false)
|
||||
|
||||
t.Log("Starting backup for database2")
|
||||
scheduler.StartBackup(database2.ID, false)
|
||||
scheduler.StartBackup(database2, false)
|
||||
|
||||
// Wait up to 10 seconds for both backups to complete
|
||||
t.Log("Waiting for both backups to complete...")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -322,7 +323,7 @@ func (c *BackupController) generateBackupFilename(
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
|
||||
// Sanitize database name for filename (replace spaces and special chars)
|
||||
safeName := sanitizeFilename(database.Name)
|
||||
safeName := files_utils.SanitizeFilename(database.Name)
|
||||
|
||||
// Determine extension based on database type
|
||||
extension := c.getBackupExtension(database.Type)
|
||||
@@ -346,33 +347,6 @@ func (c *BackupController) getBackupExtension(
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeFilename(name string) string {
|
||||
// Replace characters that are invalid in filenames
|
||||
replacer := map[rune]rune{
|
||||
' ': '_',
|
||||
'/': '-',
|
||||
'\\': '-',
|
||||
':': '-',
|
||||
'*': '-',
|
||||
'?': '-',
|
||||
'"': '-',
|
||||
'<': '-',
|
||||
'>': '-',
|
||||
'|': '-',
|
||||
}
|
||||
|
||||
result := make([]rune, 0, len(name))
|
||||
for _, char := range name {
|
||||
if replacement, exists := replacer[char]; exists {
|
||||
result = append(result, replacement)
|
||||
} else {
|
||||
result = append(result, char)
|
||||
}
|
||||
}
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
@@ -32,6 +32,7 @@ import (
|
||||
workspaces_models "databasus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
|
||||
"databasus-backend/internal/util/encryption"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
test_utils "databasus-backend/internal/util/testing"
|
||||
"databasus-backend/internal/util/tools"
|
||||
)
|
||||
@@ -956,7 +957,7 @@ func Test_SanitizeFilename(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.input, func(t *testing.T) {
|
||||
result := sanitizeFilename(tt.input)
|
||||
result := files_utils.SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
@@ -1407,7 +1408,7 @@ func createTestBackup(
|
||||
context.Background(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
logger,
|
||||
backup.ID,
|
||||
backup.ID.String(),
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/notifiers"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
@@ -23,7 +21,7 @@ type NotificationSender interface {
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
)
|
||||
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
|
||||
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
users_models "databasus-backend/internal/features/users/models"
|
||||
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
||||
util_encryption "databasus-backend/internal/util/encryption"
|
||||
files_utils "databasus-backend/internal/util/files"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -92,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
s.backupSchedulerService.StartBackup(databaseID, true)
|
||||
s.backupSchedulerService.StartBackup(database, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
@@ -181,11 +182,7 @@ func (s *BackupService) DeleteBackup(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup deleted for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup deleted for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -232,11 +229,7 @@ func (s *BackupService) CancelBackup(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup cancelled for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -276,11 +269,7 @@ func (s *BackupService) GetBackupFile(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -336,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
return nil, fmt.Errorf("failed to get storage: %w", err)
|
||||
}
|
||||
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup file: %w", err)
|
||||
}
|
||||
@@ -490,11 +479,7 @@ func (s *BackupService) WriteAuditLogForDownload(
|
||||
database *databases.Database,
|
||||
) {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backup.ID.String(),
|
||||
),
|
||||
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
|
||||
&userID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -521,7 +506,7 @@ func (s *BackupService) generateBackupFilename(
|
||||
database *databases.Database,
|
||||
) string {
|
||||
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
|
||||
safeName := sanitizeFilename(database.Name)
|
||||
safeName := files_utils.SanitizeFilename(database.Name)
|
||||
extension := s.getBackupExtension(database.Type)
|
||||
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
|
||||
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
|
||||
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
|
||||
@@ -12,8 +13,6 @@ import (
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
"databasus-backend/internal/features/storages"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type CreateBackupUsecase struct {
|
||||
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {
|
||||
|
||||
func (uc *CreateBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.CreatePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.CreateMysqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.CreateMariadbBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
case databases.DatabaseTypeMongodb:
|
||||
return uc.CreateMongodbBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -52,7 +53,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadbDump,
|
||||
@@ -136,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
@@ -187,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -204,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -46,7 +47,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMongodbBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMongodbExecutable(
|
||||
tools.MongodbExecutableMongodump,
|
||||
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(
|
||||
|
||||
func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mongodumpBin string,
|
||||
args []string,
|
||||
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -52,7 +53,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
@@ -149,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.Mysq
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
@@ -200,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -217,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"databasus-backend/internal/config"
|
||||
common "databasus-backend/internal/features/backups/backups/common"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "databasus-backend/internal/features/backups/config"
|
||||
"databasus-backend/internal/features/databases"
|
||||
@@ -53,7 +54,7 @@ type writeResult struct {
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backup,
|
||||
backupConfig,
|
||||
tools.GetPostgresqlExecutable(
|
||||
pg.Version,
|
||||
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
// streamToStorage streams pg_dump output directly to storage
|
||||
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backup *backups_core.Backup,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
pgBin string,
|
||||
args []string,
|
||||
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(
|
||||
ctx,
|
||||
uc.fieldEncryptor,
|
||||
uc.logger,
|
||||
backup.FileName,
|
||||
storageReader,
|
||||
)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
|
||||
@@ -192,6 +192,8 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
}
|
||||
}
|
||||
|
||||
oldName := existingDatabase.Name
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
@@ -201,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
if oldName != existingDatabase.Name {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Database updated and renamed from '%s' to '%s'",
|
||||
oldName,
|
||||
existingDatabase.Name,
|
||||
),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
} else {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -571,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
|
||||
return err
|
||||
}
|
||||
|
||||
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source workspace: %w", err)
|
||||
}
|
||||
|
||||
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target workspace: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
|
||||
database.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
|
||||
database.Name, sourceWorkspace.Name, targetWorkspace.Name),
|
||||
nil,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
|
||||
@@ -58,6 +58,8 @@ func (s *NotifierService) SaveNotifier(
|
||||
return err
|
||||
}
|
||||
|
||||
oldName := existingNotifier.Name
|
||||
|
||||
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -67,11 +69,23 @@ func (s *NotifierService) SaveNotifier(
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
if oldName != existingNotifier.Name {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Notifier updated and renamed from '%s' to '%s'",
|
||||
oldName,
|
||||
existingNotifier.Name,
|
||||
),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
} else {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
notifier.WorkspaceID = workspaceID
|
||||
|
||||
@@ -343,9 +357,19 @@ func (s *NotifierService) TransferNotifierToWorkspace(
|
||||
return err
|
||||
}
|
||||
|
||||
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source workspace: %w", err)
|
||||
}
|
||||
|
||||
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target workspace: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
|
||||
existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
fmt.Sprintf("Notifier transferred: %s from workspace '%s' to workspace '%s'",
|
||||
existingNotifier.Name, sourceWorkspace.Name, targetWorkspace.Name),
|
||||
&user.ID,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
|
||||
@@ -261,7 +261,7 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Database restored from backup") &&
|
||||
if strings.Contains(log.Message, "Database restored for database") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
@@ -752,7 +752,7 @@ func createTestBackup(
|
||||
context.Background(),
|
||||
fieldEncryptor,
|
||||
logger,
|
||||
backup.ID,
|
||||
backup.ID.String(),
|
||||
reader,
|
||||
); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
|
||||
@@ -190,11 +190,7 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Database restored from backup %s for database: %s",
|
||||
backupID.String(),
|
||||
database.Name,
|
||||
),
|
||||
fmt.Sprintf("Database restored for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
@@ -412,11 +408,7 @@ func (s *RestoreService) CancelRestore(
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Restore cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
restoreID.String(),
|
||||
),
|
||||
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
@@ -106,7 +106,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -141,7 +141,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
// Stream backup directly from storage
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
|
||||
@@ -154,7 +154,7 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
|
||||
|
||||
// Stream backup directly from storage
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
|
||||
@@ -105,7 +105,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
@@ -140,7 +140,7 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
// Stream backup directly from storage
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
"--no-acl",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and parent cancellation
|
||||
@@ -209,7 +209,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
}
|
||||
|
||||
// Get backup stream from storage
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
@@ -429,7 +429,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 60*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and parent cancellation
|
||||
@@ -540,12 +540,14 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
"encrypted",
|
||||
backup.Encryption == backups_config.BackupEncryptionEncrypted,
|
||||
)
|
||||
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.FileName)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := rawReader.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close backup reader", "error", err)
|
||||
|
||||
@@ -14,13 +14,13 @@ type StorageFileSaver interface {
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error
|
||||
|
||||
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
|
||||
GetFile(encryptor encryption.FieldEncryptor, fileName string) (io.ReadCloser, error)
|
||||
|
||||
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
|
||||
DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error
|
||||
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
|
||||
@@ -41,10 +41,10 @@ func (s *Storage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileID, file)
|
||||
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileName, file)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
s.LastSaveError = &lastSaveError
|
||||
@@ -58,13 +58,13 @@ func (s *Storage) SaveFile(
|
||||
|
||||
func (s *Storage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
return s.getSpecificStorage().GetFile(encryptor, fileID)
|
||||
return s.getSpecificStorage().GetFile(encryptor, fileName)
|
||||
}
|
||||
|
||||
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
|
||||
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
return s.getSpecificStorage().DeleteFile(encryptor, fileName)
|
||||
}
|
||||
|
||||
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
|
||||
@@ -229,12 +229,12 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
|
||||
context.Background(),
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
fileID.String(),
|
||||
bytes.NewReader(fileData),
|
||||
)
|
||||
require.NoError(t, err, "SaveFile should succeed")
|
||||
|
||||
file, err := tc.storage.GetFile(encryptor, fileID)
|
||||
file, err := tc.storage.GetFile(encryptor, fileID.String())
|
||||
assert.NoError(t, err, "GetFile should succeed")
|
||||
defer file.Close()
|
||||
|
||||
@@ -252,15 +252,15 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
|
||||
context.Background(),
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
fileID.String(),
|
||||
bytes.NewReader(fileData),
|
||||
)
|
||||
require.NoError(t, err, "SaveFile should succeed")
|
||||
|
||||
err = tc.storage.DeleteFile(encryptor, fileID)
|
||||
err = tc.storage.DeleteFile(encryptor, fileID.String())
|
||||
assert.NoError(t, err, "DeleteFile should succeed")
|
||||
|
||||
file, err := tc.storage.GetFile(encryptor, fileID)
|
||||
file, err := tc.storage.GetFile(encryptor, fileID.String())
|
||||
assert.Error(t, err, "GetFile should fail for non-existent file")
|
||||
if file != nil {
|
||||
file.Close()
|
||||
@@ -270,7 +270,7 @@ acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoi
|
||||
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
|
||||
// Try to delete a non-existent file
|
||||
nonExistentID := uuid.New()
|
||||
err := tc.storage.DeleteFile(encryptor, nonExistentID)
|
||||
err := tc.storage.DeleteFile(encryptor, nonExistentID.String())
|
||||
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -68,7 +68,7 @@ func (s *AzureBlobStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -82,7 +82,7 @@ func (s *AzureBlobStorage) SaveFile(
|
||||
return err
|
||||
}
|
||||
|
||||
blobName := s.buildBlobName(fileID.String())
|
||||
blobName := s.buildBlobName(fileName)
|
||||
blockBlobClient := client.ServiceClient().
|
||||
NewContainerClient(s.ContainerName).
|
||||
NewBlockBlobClient(blobName)
|
||||
@@ -157,14 +157,14 @@ func (s *AzureBlobStorage) SaveFile(
|
||||
|
||||
func (s *AzureBlobStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blobName := s.buildBlobName(fileID.String())
|
||||
blobName := s.buildBlobName(fileName)
|
||||
|
||||
response, err := client.DownloadStream(
|
||||
context.TODO(),
|
||||
@@ -179,13 +179,13 @@ func (s *AzureBlobStorage) GetFile(
|
||||
return response.Body, nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blobName := s.buildBlobName(fileID.String())
|
||||
blobName := s.buildBlobName(fileName)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), azureDeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -41,7 +41,7 @@ func (f *FTPStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -50,19 +50,19 @@ func (f *FTPStorage) SaveFile(
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to FTP storage", "fileId", fileID.String(), "host", f.Host)
|
||||
logger.Info("Starting to save file to FTP storage", "fileName", fileName, "host", f.Host)
|
||||
|
||||
conn, err := f.connect(encryptor, ftpConnectTimeout)
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to FTP", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to connect to FTP", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if quitErr := conn.Quit(); quitErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close FTP connection",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
quitErr,
|
||||
)
|
||||
@@ -73,8 +73,8 @@ func (f *FTPStorage) SaveFile(
|
||||
if err := f.ensureDirectory(conn, f.Path); err != nil {
|
||||
logger.Error(
|
||||
"Failed to ensure directory",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"path",
|
||||
f.Path,
|
||||
"error",
|
||||
@@ -84,8 +84,8 @@ func (f *FTPStorage) SaveFile(
|
||||
}
|
||||
}
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file to FTP", "fileId", fileID.String(), "filePath", filePath)
|
||||
filePath := f.getFilePath(fileName)
|
||||
logger.Debug("Uploading file to FTP", "fileName", fileName, "filePath", filePath)
|
||||
|
||||
ctxReader := &contextReader{ctx: ctx, reader: file}
|
||||
|
||||
@@ -93,18 +93,18 @@ func (f *FTPStorage) SaveFile(
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("FTP upload cancelled", "fileId", fileID.String())
|
||||
logger.Info("FTP upload cancelled", "fileName", fileName)
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error("Failed to upload file to FTP", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to upload file to FTP", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to upload file to FTP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to FTP storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
@@ -113,14 +113,14 @@ func (f *FTPStorage) SaveFile(
|
||||
|
||||
func (f *FTPStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
conn, err := f.connect(encryptor, ftpConnectTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
filePath := f.getFilePath(fileName)
|
||||
|
||||
resp, err := conn.Retr(filePath)
|
||||
if err != nil {
|
||||
@@ -134,7 +134,7 @@ func (f *FTPStorage) GetFile(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ftpDeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -146,7 +146,7 @@ func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
|
||||
_ = conn.Quit()
|
||||
}()
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
filePath := f.getFilePath(fileName)
|
||||
|
||||
_, err = conn.FileSize(filePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -50,21 +50,19 @@ func (s *GoogleDriveStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
|
||||
filename := fileID.String()
|
||||
|
||||
folderID, err := s.ensureBackupsFolderExists(ctx, driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create/find backups folder: %w", err)
|
||||
}
|
||||
|
||||
_ = s.deleteByName(ctx, driveService, filename, folderID)
|
||||
_ = s.deleteByName(ctx, driveService, fileName, folderID)
|
||||
|
||||
fileMeta := &drive.File{
|
||||
Name: filename,
|
||||
Name: fileName,
|
||||
Parents: []string{folderID},
|
||||
}
|
||||
|
||||
@@ -91,7 +89,7 @@ func (s *GoogleDriveStorage) SaveFile(
|
||||
logger.Info(
|
||||
"file uploaded to Google Drive",
|
||||
"name",
|
||||
filename,
|
||||
fileName,
|
||||
"folder",
|
||||
"databasus_backups",
|
||||
)
|
||||
@@ -152,7 +150,7 @@ func (r *backpressureReader) Read(p []byte) (n int, err error) {
|
||||
|
||||
func (s *GoogleDriveStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
var result io.ReadCloser
|
||||
err := s.withRetryOnAuth(
|
||||
@@ -164,7 +162,7 @@ func (s *GoogleDriveStorage) GetFile(
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
}
|
||||
|
||||
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
|
||||
fileIDGoogle, err := s.lookupFileID(driveService, fileName, folderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -184,7 +182,7 @@ func (s *GoogleDriveStorage) GetFile(
|
||||
|
||||
func (s *GoogleDriveStorage) DeleteFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), gdDeleteTimeout)
|
||||
defer cancel()
|
||||
@@ -195,7 +193,7 @@ func (s *GoogleDriveStorage) DeleteFile(
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
}
|
||||
|
||||
return s.deleteByName(ctx, driveService, fileID.String(), folderID)
|
||||
return s.deleteByName(ctx, driveService, fileName, folderID)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ func (l *LocalStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -45,7 +45,7 @@ func (l *LocalStorage) SaveFile(
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
|
||||
logger.Info("Starting to save file to local storage", "fileName", fileName)
|
||||
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
@@ -54,15 +54,15 @@ func (l *LocalStorage) SaveFile(
|
||||
return fmt.Errorf("failed to ensure directories: %w", err)
|
||||
}
|
||||
|
||||
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileID.String())
|
||||
logger.Debug("Creating temp file", "fileId", fileID.String(), "tempPath", tempFilePath)
|
||||
tempFilePath := filepath.Join(config.GetEnv().TempFolder, fileName)
|
||||
logger.Debug("Creating temp file", "fileName", fileName, "tempPath", tempFilePath)
|
||||
|
||||
tempFile, err := os.Create(tempFilePath)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to create temp file",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"tempPath",
|
||||
tempFilePath,
|
||||
"error",
|
||||
@@ -74,29 +74,29 @@ func (l *LocalStorage) SaveFile(
|
||||
_ = tempFile.Close()
|
||||
}()
|
||||
|
||||
logger.Debug("Copying file data to temp file", "fileId", fileID.String())
|
||||
logger.Debug("Copying file data to temp file", "fileName", fileName)
|
||||
_, err = copyWithContext(ctx, tempFile, file)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write to temp file", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to write to temp file", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to write to temp file: %w", err)
|
||||
}
|
||||
|
||||
if err = tempFile.Sync(); err != nil {
|
||||
logger.Error("Failed to sync temp file", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to sync temp file", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to sync temp file: %w", err)
|
||||
}
|
||||
|
||||
// Close the temp file explicitly before moving it (required on Windows)
|
||||
if err = tempFile.Close(); err != nil {
|
||||
logger.Error("Failed to close temp file", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to close temp file", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to close temp file: %w", err)
|
||||
}
|
||||
|
||||
finalPath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
|
||||
finalPath := filepath.Join(config.GetEnv().DataFolder, fileName)
|
||||
logger.Debug(
|
||||
"Moving file from temp to final location",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"finalPath",
|
||||
finalPath,
|
||||
)
|
||||
@@ -105,8 +105,8 @@ func (l *LocalStorage) SaveFile(
|
||||
if err = os.Rename(tempFilePath, finalPath); err != nil {
|
||||
logger.Error(
|
||||
"Failed to move file from temp to backups",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"tempPath",
|
||||
tempFilePath,
|
||||
"finalPath",
|
||||
@@ -119,8 +119,8 @@ func (l *LocalStorage) SaveFile(
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to local storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"finalPath",
|
||||
finalPath,
|
||||
)
|
||||
@@ -130,12 +130,12 @@ func (l *LocalStorage) SaveFile(
|
||||
|
||||
func (l *LocalStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("file not found: %s", fileID.String())
|
||||
return nil, fmt.Errorf("file not found: %s", fileName)
|
||||
}
|
||||
|
||||
file, err := os.Open(filePath)
|
||||
@@ -146,8 +146,8 @@ func (l *LocalStorage) GetFile(
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
|
||||
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileName)
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return nil
|
||||
|
||||
@@ -46,7 +46,7 @@ func (n *NASStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -55,19 +55,19 @@ func (n *NASStorage) SaveFile(
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
|
||||
logger.Info("Starting to save file to NAS storage", "fileName", fileName, "host", n.Host)
|
||||
|
||||
session, err := n.createSessionWithContext(ctx, encryptor)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to create NAS session", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to create NAS session: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if logoffErr := session.Logoff(); logoffErr != nil {
|
||||
logger.Error(
|
||||
"Failed to logoff NAS session",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
logoffErr,
|
||||
)
|
||||
@@ -78,8 +78,8 @@ func (n *NASStorage) SaveFile(
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to mount NAS share",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"share",
|
||||
n.Share,
|
||||
"error",
|
||||
@@ -91,8 +91,8 @@ func (n *NASStorage) SaveFile(
|
||||
if umountErr := fs.Umount(); umountErr != nil {
|
||||
logger.Error(
|
||||
"Failed to unmount NAS share",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
umountErr,
|
||||
)
|
||||
@@ -104,8 +104,8 @@ func (n *NASStorage) SaveFile(
|
||||
if err := n.ensureDirectory(fs, n.Path); err != nil {
|
||||
logger.Error(
|
||||
"Failed to ensure directory",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"path",
|
||||
n.Path,
|
||||
"error",
|
||||
@@ -115,15 +115,15 @@ func (n *NASStorage) SaveFile(
|
||||
}
|
||||
}
|
||||
|
||||
filePath := n.getFilePath(fileID.String())
|
||||
logger.Debug("Creating file on NAS", "fileId", fileID.String(), "filePath", filePath)
|
||||
filePath := n.getFilePath(fileName)
|
||||
logger.Debug("Creating file on NAS", "fileName", fileName, "filePath", filePath)
|
||||
|
||||
nasFile, err := fs.Create(filePath)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to create file on NAS",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"filePath",
|
||||
filePath,
|
||||
"error",
|
||||
@@ -133,21 +133,21 @@ func (n *NASStorage) SaveFile(
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := nasFile.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close NAS file", "fileId", fileID.String(), "error", closeErr)
|
||||
logger.Error("Failed to close NAS file", "fileName", fileName, "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Debug("Copying file data to NAS", "fileId", fileID.String())
|
||||
logger.Debug("Copying file data to NAS", "fileName", fileName)
|
||||
_, err = copyWithContext(ctx, nasFile, file)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write file to NAS", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to write file to NAS", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to write file to NAS: %w", err)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to NAS storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
@@ -156,7 +156,7 @@ func (n *NASStorage) SaveFile(
|
||||
|
||||
func (n *NASStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
session, err := n.createSession(encryptor)
|
||||
if err != nil {
|
||||
@@ -169,14 +169,14 @@ func (n *NASStorage) GetFile(
|
||||
return nil, fmt.Errorf("failed to mount share '%s': %w", n.Share, err)
|
||||
}
|
||||
|
||||
filePath := n.getFilePath(fileID.String())
|
||||
filePath := n.getFilePath(fileName)
|
||||
|
||||
// Check if file exists
|
||||
_, err = fs.Stat(filePath)
|
||||
if err != nil {
|
||||
_ = fs.Umount()
|
||||
_ = session.Logoff()
|
||||
return nil, fmt.Errorf("file not found: %s", fileID.String())
|
||||
return nil, fmt.Errorf("file not found: %s", fileName)
|
||||
}
|
||||
|
||||
nasFile, err := fs.Open(filePath)
|
||||
@@ -194,7 +194,7 @@ func (n *NASStorage) GetFile(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), nasDeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -214,7 +214,7 @@ func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid
|
||||
_ = fs.Umount()
|
||||
}()
|
||||
|
||||
filePath := n.getFilePath(fileID.String())
|
||||
filePath := n.getFilePath(fileName)
|
||||
|
||||
_, err = fs.Stat(filePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -41,7 +41,7 @@ func (r *RcloneStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -50,28 +50,28 @@ func (r *RcloneStorage) SaveFile(
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to rclone storage", "fileId", fileID.String())
|
||||
logger.Info("Starting to save file to rclone storage", "fileName", fileName)
|
||||
|
||||
remoteFs, err := r.getFs(ctx, encryptor)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create rclone filesystem", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to create rclone filesystem", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file via rclone", "fileId", fileID.String(), "filePath", filePath)
|
||||
filePath := r.getFilePath(fileName)
|
||||
logger.Debug("Uploading file via rclone", "fileName", fileName, "filePath", filePath)
|
||||
|
||||
_, err = operations.Rcat(ctx, remoteFs, filePath, io.NopCloser(file), time.Now().UTC(), nil)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Rclone upload cancelled", "fileId", fileID.String())
|
||||
logger.Info("Rclone upload cancelled", "fileName", fileName)
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error(
|
||||
"Failed to upload file via rclone",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
@@ -81,8 +81,8 @@ func (r *RcloneStorage) SaveFile(
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to rclone storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
@@ -91,7 +91,7 @@ func (r *RcloneStorage) SaveFile(
|
||||
|
||||
func (r *RcloneStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -100,7 +100,7 @@ func (r *RcloneStorage) GetFile(
|
||||
return nil, fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
filePath := r.getFilePath(fileName)
|
||||
|
||||
obj, err := remoteFs.NewObject(ctx, filePath)
|
||||
if err != nil {
|
||||
@@ -115,7 +115,7 @@ func (r *RcloneStorage) GetFile(
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), rcloneDeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -124,7 +124,7 @@ func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID u
|
||||
return fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
filePath := r.getFilePath(fileName)
|
||||
|
||||
obj, err := remoteFs.NewObject(ctx, filePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -55,7 +55,7 @@ func (s *S3Storage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -69,7 +69,7 @@ func (s *S3Storage) SaveFile(
|
||||
return err
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(fileID.String())
|
||||
objectKey := s.buildObjectKey(fileName)
|
||||
|
||||
uploadID, err := coreClient.NewMultipartUpload(
|
||||
ctx,
|
||||
@@ -184,14 +184,14 @@ func (s *S3Storage) SaveFile(
|
||||
|
||||
func (s *S3Storage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(fileID.String())
|
||||
objectKey := s.buildObjectKey(fileName)
|
||||
|
||||
object, err := client.GetObject(
|
||||
context.TODO(),
|
||||
@@ -221,13 +221,13 @@ func (s *S3Storage) GetFile(
|
||||
return object, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(fileID.String())
|
||||
objectKey := s.buildObjectKey(fileName)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s3DeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -41,7 +41,7 @@ func (s *SFTPStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
@@ -50,19 +50,19 @@ func (s *SFTPStorage) SaveFile(
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to SFTP storage", "fileId", fileID.String(), "host", s.Host)
|
||||
logger.Info("Starting to save file to SFTP storage", "fileName", fileName, "host", s.Host)
|
||||
|
||||
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to SFTP", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to connect to SFTP", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close SFTP client",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
@@ -70,8 +70,8 @@ func (s *SFTPStorage) SaveFile(
|
||||
if closeErr := sshConn.Close(); closeErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close SSH connection",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
@@ -82,8 +82,8 @@ func (s *SFTPStorage) SaveFile(
|
||||
if err := s.ensureDirectory(client, s.Path); err != nil {
|
||||
logger.Error(
|
||||
"Failed to ensure directory",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"path",
|
||||
s.Path,
|
||||
"error",
|
||||
@@ -93,12 +93,12 @@ func (s *SFTPStorage) SaveFile(
|
||||
}
|
||||
}
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file to SFTP", "fileId", fileID.String(), "filePath", filePath)
|
||||
filePath := s.getFilePath(fileName)
|
||||
logger.Debug("Uploading file to SFTP", "fileName", fileName, "filePath", filePath)
|
||||
|
||||
remoteFile, err := client.Create(filePath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create remote file", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to create remote file", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to create remote file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
@@ -111,18 +111,18 @@ func (s *SFTPStorage) SaveFile(
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("SFTP upload cancelled", "fileId", fileID.String())
|
||||
logger.Info("SFTP upload cancelled", "fileName", fileName)
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error("Failed to upload file to SFTP", "fileId", fileID.String(), "error", err)
|
||||
logger.Error("Failed to upload file to SFTP", "fileName", fileName, "error", err)
|
||||
return fmt.Errorf("failed to upload file to SFTP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to SFTP storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"fileName",
|
||||
fileName,
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
@@ -131,14 +131,14 @@ func (s *SFTPStorage) SaveFile(
|
||||
|
||||
func (s *SFTPStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
fileName string,
|
||||
) (io.ReadCloser, error) {
|
||||
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
filePath := s.getFilePath(fileName)
|
||||
|
||||
remoteFile, err := client.Open(filePath)
|
||||
if err != nil {
|
||||
@@ -154,7 +154,7 @@ func (s *SFTPStorage) GetFile(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileName string) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sftpDeleteTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -167,7 +167,7 @@ func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uui
|
||||
_ = sshConn.Close()
|
||||
}()
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
filePath := s.getFilePath(fileName)
|
||||
|
||||
_, err = client.Stat(filePath)
|
||||
if err != nil {
|
||||
|
||||
@@ -92,6 +92,8 @@ func (s *StorageService) SaveStorage(
|
||||
|
||||
existingStorage.Update(storage)
|
||||
|
||||
oldName := existingStorage.Name
|
||||
|
||||
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -105,11 +107,19 @@ func (s *StorageService) SaveStorage(
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
if oldName != existingStorage.Name {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Storage renamed from '%s' to '%s'", oldName, existingStorage.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
} else {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
storage.WorkspaceID = workspaceID
|
||||
|
||||
@@ -368,9 +378,26 @@ func (s *StorageService) TransferStorageToWorkspace(
|
||||
return err
|
||||
}
|
||||
|
||||
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get source workspace: %w", err)
|
||||
}
|
||||
|
||||
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get target workspace: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Storage transferred: %s from workspace %s to workspace %s",
|
||||
existingStorage.Name, sourceWorkspaceID, targetWorkspaceID),
|
||||
fmt.Sprintf("Storage transferred out: %s to workspace '%s'",
|
||||
existingStorage.Name, targetWorkspace.Name),
|
||||
&user.ID,
|
||||
&sourceWorkspaceID,
|
||||
)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Storage transferred in: %s from workspace '%s'",
|
||||
existingStorage.Name, sourceWorkspace.Name),
|
||||
&user.ID,
|
||||
&targetWorkspaceID,
|
||||
)
|
||||
|
||||
@@ -726,41 +726,28 @@ func Test_InviteUserToWorkspace_MembershipReceivedAfterSignUp(t *testing.T) {
|
||||
|
||||
assert.Equal(t, workspaces_dto.AddStatusInvited, inviteResponse.Status)
|
||||
|
||||
// 3. Sign up the invited user
|
||||
// 3. Sign up the invited user (now returns token directly)
|
||||
signUpRequest := users_dto.SignUpRequestDTO{
|
||||
Email: inviteEmail,
|
||||
Password: "testpassword123",
|
||||
Name: "Invited User",
|
||||
}
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
var signInResponse users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signup",
|
||||
"",
|
||||
signUpRequest,
|
||||
http.StatusOK,
|
||||
)
|
||||
assert.Contains(t, string(resp.Body), "User created successfully")
|
||||
|
||||
// 4. Sign in the newly registered user
|
||||
signInRequest := users_dto.SignInRequestDTO{
|
||||
Email: inviteEmail,
|
||||
Password: "testpassword123",
|
||||
}
|
||||
|
||||
var signInResponse users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signin",
|
||||
"",
|
||||
signInRequest,
|
||||
http.StatusOK,
|
||||
&signInResponse,
|
||||
)
|
||||
|
||||
// 5. Verify user is automatically added as member to workspace
|
||||
assert.NotEmpty(t, signInResponse.Token)
|
||||
assert.Equal(t, inviteEmail, signInResponse.Email)
|
||||
|
||||
// 4. Verify user is automatically added as member to workspace
|
||||
var membersResponse workspaces_dto.GetMembersResponseDTO
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
user_middleware "databasus-backend/internal/features/users/middleware"
|
||||
users_services "databasus-backend/internal/features/users/services"
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
cloudflare_turnstile "databasus-backend/internal/util/cloudflare_turnstile"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -51,7 +52,7 @@ func (c *UserController) RegisterProtectedRoutes(router *gin.RouterGroup) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body users_dto.SignUpRequestDTO true "User signup data"
|
||||
// @Success 200
|
||||
// @Success 200 {object} users_dto.SignInResponseDTO
|
||||
// @Failure 400
|
||||
// @Router /users/signup [post]
|
||||
func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
@@ -61,13 +62,41 @@ func (c *UserController) SignUp(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
err := c.userService.SignUp(&request)
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
user, err := c.userService.SignUp(&request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
|
||||
response, err := c.userService.GenerateAccessToken(user)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// SignIn
|
||||
@@ -88,6 +117,28 @@ func (c *UserController) SignIn(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(request.Email, "signin", 10, 1*time.Minute)
|
||||
if !allowed {
|
||||
ctx.JSON(
|
||||
@@ -363,6 +414,28 @@ func (c *UserController) SendResetPasswordCode(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Verify Cloudflare Turnstile if enabled
|
||||
turnstileService := cloudflare_turnstile.GetCloudflareTurnstileService()
|
||||
if turnstileService.IsEnabled() {
|
||||
if request.CloudflareTurnstileToken == nil || *request.CloudflareTurnstileToken == "" {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification required"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := ctx.ClientIP()
|
||||
isValid, err := turnstileService.VerifyToken(*request.CloudflareTurnstileToken, clientIP)
|
||||
if err != nil || !isValid {
|
||||
ctx.JSON(
|
||||
http.StatusBadRequest,
|
||||
gin.H{"error": "Cloudflare Turnstile verification failed"},
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
allowed, _ := c.rateLimiter.CheckLimit(
|
||||
request.Email,
|
||||
"reset-password",
|
||||
|
||||
@@ -27,7 +27,20 @@ func Test_SignUpUser_WithValidData_UserCreated(t *testing.T) {
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
test_utils.MakePostRequest(t, router, "/api/v1/users/signup", "", request, http.StatusOK)
|
||||
var response users_dto.SignInResponseDTO
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/users/signup",
|
||||
"",
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, response.Token)
|
||||
assert.NotEqual(t, uuid.Nil, response.UserID)
|
||||
assert.Equal(t, request.Email, response.Email)
|
||||
}
|
||||
|
||||
func Test_SignUpUser_WithInvalidJSON_ReturnsBadRequest(t *testing.T) {
|
||||
|
||||
@@ -9,14 +9,16 @@ import (
|
||||
)
|
||||
|
||||
type SignUpRequestDTO struct {
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required,min=8"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type SignInRequestDTO struct {
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type SignInResponseDTO struct {
|
||||
@@ -94,7 +96,8 @@ type OAuthCallbackResponseDTO struct {
|
||||
}
|
||||
|
||||
type SendResetPasswordCodeRequestDTO struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
CloudflareTurnstileToken *string `json:"cloudflareTurnstileToken"`
|
||||
}
|
||||
|
||||
type ResetPasswordRequestDTO struct {
|
||||
|
||||
@@ -44,19 +44,19 @@ func (s *UserService) SetEmailSender(sender users_interfaces.EmailSender) {
|
||||
s.emailSender = sender
|
||||
}
|
||||
|
||||
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) (*users_models.User, error) {
|
||||
existingUser, err := s.userRepository.GetUserByEmail(request.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check existing user: %w", err)
|
||||
return nil, fmt.Errorf("failed to check existing user: %w", err)
|
||||
}
|
||||
|
||||
if existingUser != nil && existingUser.Status != users_enums.UserStatusInvited {
|
||||
return errors.New("user with this email already exists")
|
||||
return nil, errors.New("user with this email already exists")
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(request.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash password: %w", err)
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
hashedPasswordStr := string(hashedPassword)
|
||||
@@ -67,39 +67,45 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
existingUser.ID,
|
||||
hashedPasswordStr,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to set password: %w", err)
|
||||
return nil, fmt.Errorf("failed to set password: %w", err)
|
||||
}
|
||||
|
||||
if err := s.userRepository.UpdateUserStatus(
|
||||
existingUser.ID,
|
||||
users_enums.UserStatusActive,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to activate user: %w", err)
|
||||
return nil, fmt.Errorf("failed to activate user: %w", err)
|
||||
}
|
||||
|
||||
name := request.Name
|
||||
if err := s.userRepository.UpdateUserInfo(existingUser.ID, &name, nil); err != nil {
|
||||
return fmt.Errorf("failed to update name: %w", err)
|
||||
return nil, fmt.Errorf("failed to update name: %w", err)
|
||||
}
|
||||
|
||||
// Fetch updated user to ensure we have the latest data
|
||||
updatedUser, err := s.userRepository.GetUserByID(existingUser.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get updated user: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
fmt.Sprintf("Invited user completed registration: %s", existingUser.Email),
|
||||
&existingUser.ID,
|
||||
fmt.Sprintf("Invited user completed registration: %s", updatedUser.Email),
|
||||
&updatedUser.ID,
|
||||
nil,
|
||||
)
|
||||
|
||||
return nil
|
||||
return updatedUser, nil
|
||||
}
|
||||
|
||||
// Get settings to check registration policy for new users
|
||||
settings, err := s.settingsService.GetSettings()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get settings: %w", err)
|
||||
return nil, fmt.Errorf("failed to get settings: %w", err)
|
||||
}
|
||||
|
||||
// Check if external registrations are allowed
|
||||
if !settings.IsAllowExternalRegistrations {
|
||||
return errors.New("external registration is disabled")
|
||||
return nil, errors.New("external registration is disabled")
|
||||
}
|
||||
|
||||
user := &users_models.User{
|
||||
@@ -114,7 +120,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
}
|
||||
|
||||
if err := s.userRepository.CreateUser(user); err != nil {
|
||||
return fmt.Errorf("failed to create user: %w", err)
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
@@ -123,7 +129,7 @@ func (s *UserService) SignUp(request *users_dto.SignUpRequestDTO) error {
|
||||
nil,
|
||||
)
|
||||
|
||||
return nil
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *UserService) SignIn(
|
||||
@@ -258,6 +264,7 @@ func (s *UserService) GenerateAccessToken(
|
||||
|
||||
return &users_dto.SignInResponseDTO{
|
||||
UserID: user.ID,
|
||||
Email: user.Email,
|
||||
Token: tokenString,
|
||||
}, nil
|
||||
}
|
||||
@@ -383,7 +390,7 @@ func (s *UserService) InviteUser(
|
||||
|
||||
message := fmt.Sprintf("User invited: %s", request.Email)
|
||||
if request.IntendedWorkspaceID != nil {
|
||||
message += fmt.Sprintf(" for workspace %s", request.IntendedWorkspaceID.String())
|
||||
message += " for workspace"
|
||||
}
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
message,
|
||||
@@ -430,6 +437,9 @@ func (s *UserService) UpdateUserInfo(
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
oldEmail := user.Email
|
||||
oldName := user.Name
|
||||
|
||||
if user.Email == "admin" && request.Email != nil && *request.Email != user.Email {
|
||||
return errors.New("admin email cannot be changed")
|
||||
}
|
||||
@@ -448,7 +458,28 @@ func (s *UserService) UpdateUserInfo(
|
||||
return fmt.Errorf("failed to update user info: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogWriter.WriteAuditLog("User info updated", &userID, nil)
|
||||
var auditMessages []string
|
||||
if request.Email != nil && *request.Email != oldEmail {
|
||||
auditMessages = append(
|
||||
auditMessages,
|
||||
fmt.Sprintf("Email changed from '%s' to '%s'", oldEmail, *request.Email),
|
||||
)
|
||||
}
|
||||
if request.Name != nil && *request.Name != oldName {
|
||||
auditMessages = append(
|
||||
auditMessages,
|
||||
fmt.Sprintf("Name changed from '%s' to '%s'", oldName, *request.Name),
|
||||
)
|
||||
}
|
||||
|
||||
if len(auditMessages) > 0 {
|
||||
for _, message := range auditMessages {
|
||||
s.auditLogWriter.WriteAuditLog(message, &userID, nil)
|
||||
}
|
||||
} else {
|
||||
s.auditLogWriter.WriteAuditLog("User info updated", &userID, nil)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -463,6 +494,178 @@ func (s *UserService) HandleGitHubOAuth(
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) HandleGoogleOAuth(
|
||||
code, redirectUri string,
|
||||
) (*users_dto.OAuthCallbackResponseDTO, error) {
|
||||
return s.handleGoogleOAuthWithEndpoint(
|
||||
code,
|
||||
redirectUri,
|
||||
google.Endpoint,
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) SendResetPasswordCode(email string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
// Silently succeed for non-existent users to prevent enumeration attacks
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only active users can reset passwords
|
||||
if user.Status != users_enums.UserStatusActive {
|
||||
return errors.New("only active users can reset their password")
|
||||
}
|
||||
|
||||
// Check rate limiting - max 3 codes per hour
|
||||
oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
|
||||
recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check rate limit: %w", err)
|
||||
}
|
||||
|
||||
if recentCount >= 3 {
|
||||
return errors.New("too many password reset attempts, please try again later")
|
||||
}
|
||||
|
||||
// Generate 6-digit random code using crypto/rand for better randomness
|
||||
codeNum := make([]byte, 4)
|
||||
_, err = io.ReadFull(rand.Reader, codeNum)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate random code: %w", err)
|
||||
}
|
||||
|
||||
// Convert bytes to uint32 and modulo to get 6 digits
|
||||
randomInt := uint32(
|
||||
codeNum[0],
|
||||
)<<24 | uint32(
|
||||
codeNum[1],
|
||||
)<<16 | uint32(
|
||||
codeNum[2],
|
||||
)<<8 | uint32(
|
||||
codeNum[3],
|
||||
)
|
||||
code := fmt.Sprintf("%06d", randomInt%1000000)
|
||||
|
||||
// Hash the code
|
||||
hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash code: %w", err)
|
||||
}
|
||||
|
||||
// Store in database with 1 hour expiration
|
||||
resetCode := &users_models.PasswordResetCode{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
HashedCode: string(hashedCode),
|
||||
ExpiresAt: time.Now().UTC().Add(1 * time.Hour),
|
||||
IsUsed: false,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
|
||||
return fmt.Errorf("failed to create reset code: %w", err)
|
||||
}
|
||||
|
||||
// Send email with code
|
||||
if s.emailSender != nil {
|
||||
subject := "Password Reset Code"
|
||||
body := fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
</head>
|
||||
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
|
||||
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
|
||||
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
You have requested to reset your password. Please use the following code to complete the password reset process:
|
||||
</p>
|
||||
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
|
||||
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
|
||||
</div>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
This code will expire in <strong>1 hour</strong>.
|
||||
</p>
|
||||
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
|
||||
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
|
||||
</p>
|
||||
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
|
||||
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
|
||||
This is an automated message. Please do not reply to this email.
|
||||
</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, code)
|
||||
|
||||
if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
fmt.Sprintf("Password reset code sent to: %s", user.Email),
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) ResetPassword(email, code, newPassword string) error {
|
||||
user, err := s.userRepository.GetUserByEmail(email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return errors.New("user with this email does not exist")
|
||||
}
|
||||
|
||||
// Get valid reset code for user
|
||||
resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
|
||||
if err != nil {
|
||||
return errors.New("invalid or expired reset code")
|
||||
}
|
||||
|
||||
// Verify code matches
|
||||
err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
|
||||
if err != nil {
|
||||
return errors.New("invalid reset code")
|
||||
}
|
||||
|
||||
// Mark code as used
|
||||
if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
|
||||
return fmt.Errorf("failed to mark code as used: %w", err)
|
||||
}
|
||||
|
||||
// Update user password
|
||||
if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// Audit log
|
||||
if s.auditLogWriter != nil {
|
||||
s.auditLogWriter.WriteAuditLog(
|
||||
"Password reset via email code",
|
||||
&user.ID,
|
||||
nil,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UserService) handleGitHubOAuthWithEndpoint(
|
||||
code, redirectUri string,
|
||||
endpoint oauth2.Endpoint,
|
||||
@@ -529,17 +732,6 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
|
||||
return s.getOrCreateUserFromOAuth(oauthID, email, name, "github")
|
||||
}
|
||||
|
||||
func (s *UserService) HandleGoogleOAuth(
|
||||
code, redirectUri string,
|
||||
) (*users_dto.OAuthCallbackResponseDTO, error) {
|
||||
return s.handleGoogleOAuthWithEndpoint(
|
||||
code,
|
||||
redirectUri,
|
||||
google.Endpoint,
|
||||
"https://www.googleapis.com/oauth2/v2/userinfo",
|
||||
)
|
||||
}
|
||||
|
||||
func (s *UserService) handleGoogleOAuthWithEndpoint(
|
||||
code, redirectUri string,
|
||||
endpoint oauth2.Endpoint,
|
||||
@@ -805,164 +997,3 @@ func (s *UserService) fetchGitHubPrimaryEmail(
|
||||
|
||||
return "", errors.New("github account has no accessible email")
|
||||
}
|
||||
|
||||
// SendResetPasswordCode emails a one-time 6-digit password-reset code to the
// account registered under email.
//
// Behavior visible in this implementation:
//   - Unknown emails return nil silently (user-enumeration protection).
//   - Only users with status "active" may request a reset.
//   - Rate limited to 3 codes per user per rolling hour.
//   - The code is generated with crypto/rand, stored bcrypt-hashed, and
//     expires after 1 hour.
//   - Email delivery and audit logging are best-effort optional collaborators
//     (skipped when the corresponding dependency is nil).
func (s *UserService) SendResetPasswordCode(email string) error {
	user, err := s.userRepository.GetUserByEmail(email)
	if err != nil {
		return fmt.Errorf("failed to get user: %w", err)
	}

	// Silently succeed for non-existent users to prevent enumeration attacks
	if user == nil {
		return nil
	}

	// Only active users can reset passwords
	if user.Status != users_enums.UserStatusActive {
		return errors.New("only active users can reset their password")
	}

	// Check rate limiting - max 3 codes per hour
	oneHourAgo := time.Now().UTC().Add(-1 * time.Hour)
	recentCount, err := s.passwordResetRepository.CountRecentCodesByUserID(user.ID, oneHourAgo)
	if err != nil {
		return fmt.Errorf("failed to check rate limit: %w", err)
	}

	if recentCount >= 3 {
		return errors.New("too many password reset attempts, please try again later")
	}

	// Generate 6-digit random code using crypto/rand for better randomness
	codeNum := make([]byte, 4)
	_, err = io.ReadFull(rand.Reader, codeNum)
	if err != nil {
		return fmt.Errorf("failed to generate random code: %w", err)
	}

	// Convert bytes to uint32 and modulo to get 6 digits.
	// NOTE(review): randomInt % 1000000 carries a tiny modulo bias (2^32 is
	// not a multiple of 10^6); negligible for a rate-limited 6-digit code.
	randomInt := uint32(
		codeNum[0],
	)<<24 | uint32(
		codeNum[1],
	)<<16 | uint32(
		codeNum[2],
	)<<8 | uint32(
		codeNum[3],
	)
	code := fmt.Sprintf("%06d", randomInt%1000000)

	// Hash the code so a database leak does not expose usable reset codes.
	hashedCode, err := bcrypt.GenerateFromPassword([]byte(code), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("failed to hash code: %w", err)
	}

	// Store in database with 1 hour expiration
	resetCode := &users_models.PasswordResetCode{
		ID:         uuid.New(),
		UserID:     user.ID,
		HashedCode: string(hashedCode),
		ExpiresAt:  time.Now().UTC().Add(1 * time.Hour),
		IsUsed:     false,
		CreatedAt:  time.Now().UTC(),
	}

	if err := s.passwordResetRepository.CreateResetCode(resetCode); err != nil {
		return fmt.Errorf("failed to create reset code: %w", err)
	}

	// Send email with code (plain-text code is embedded in the HTML body;
	// only the hash is persisted above).
	if s.emailSender != nil {
		subject := "Password Reset Code"
		body := fmt.Sprintf(`
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
</head>
<body style="margin: 0; padding: 0; font-family: Arial, sans-serif; background-color: #f4f4f4;">
<div style="max-width: 600px; margin: 0 auto; background-color: #ffffff; padding: 20px;">
<h2 style="color: #333333; margin-bottom: 20px;">Password Reset Request</h2>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
You have requested to reset your password. Please use the following code to complete the password reset process:
</p>
<div style="background-color: #f8f9fa; border: 2px solid #e9ecef; border-radius: 8px; padding: 20px; text-align: center; margin: 30px 0;">
<h1 style="color: #2c3e50; font-size: 36px; margin: 0; letter-spacing: 8px; font-family: monospace;">%s</h1>
</div>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
This code will expire in <strong>1 hour</strong>.
</p>
<p style="color: #666666; line-height: 1.6; margin-bottom: 20px;">
If you did not request a password reset, please ignore this email. Your password will remain unchanged.
</p>
<hr style="border: none; border-top: 1px solid #e9ecef; margin: 30px 0;">
<p style="color: #999999; font-size: 12px; line-height: 1.6;">
This is an automated message. Please do not reply to this email.
</p>
</div>
</body>
</html>
`, code)

		if err := s.emailSender.SendEmail(user.Email, subject, body); err != nil {
			return fmt.Errorf("failed to send email: %w", err)
		}
	}

	// Audit log
	if s.auditLogWriter != nil {
		s.auditLogWriter.WriteAuditLog(
			fmt.Sprintf("Password reset code sent to: %s", user.Email),
			&user.ID,
			nil,
		)
	}

	return nil
}
|
||||
|
||||
// ResetPassword completes the email-code password-reset flow: it validates
// the supplied code against the latest valid (unexpired, unused) code for the
// account, marks the code as consumed, and sets the new password.
//
// NOTE(review): unlike SendResetPasswordCode (which silently succeeds for
// unknown emails), this path returns a distinct "user with this email does
// not exist" error, which permits account enumeration — consider returning
// the generic "invalid or expired reset code" message instead; confirm the
// intended UX before changing.
func (s *UserService) ResetPassword(email, code, newPassword string) error {
	user, err := s.userRepository.GetUserByEmail(email)
	if err != nil {
		return fmt.Errorf("failed to get user: %w", err)
	}

	if user == nil {
		return errors.New("user with this email does not exist")
	}

	// Get valid reset code for user
	resetCode, err := s.passwordResetRepository.GetValidCodeByUserID(user.ID)
	if err != nil {
		// Repository error is collapsed into a generic message so callers
		// cannot distinguish "no code" from storage failures.
		return errors.New("invalid or expired reset code")
	}

	// Verify code matches the bcrypt hash stored at send time.
	err = bcrypt.CompareHashAndPassword([]byte(resetCode.HashedCode), []byte(code))
	if err != nil {
		return errors.New("invalid reset code")
	}

	// Mark code as used before changing the password so the code is
	// single-use even if the password update below fails.
	if err := s.passwordResetRepository.MarkCodeAsUsed(resetCode.ID); err != nil {
		return fmt.Errorf("failed to mark code as used: %w", err)
	}

	// Update user password
	if err := s.ChangeUserPassword(user.ID, newPassword); err != nil {
		return fmt.Errorf("failed to update password: %w", err)
	}

	// Audit log
	if s.auditLogWriter != nil {
		s.auditLogWriter.WriteAuditLog(
			"Password reset via email code",
			&user.ID,
			nil,
		)
	}

	return nil
}
|
||||
|
||||
@@ -129,6 +129,8 @@ func (s *WorkspaceService) UpdateWorkspace(
|
||||
return nil, fmt.Errorf("failed to get workspace: %w", err)
|
||||
}
|
||||
|
||||
oldName := existingWorkspace.Name
|
||||
|
||||
updateDTO.ID = workspaceID
|
||||
updateDTO.CreatedAt = existingWorkspace.CreatedAt
|
||||
|
||||
@@ -138,11 +140,19 @@ func (s *WorkspaceService) UpdateWorkspace(
|
||||
return nil, fmt.Errorf("failed to update workspace: %w", err)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Workspace updated: %s", updateDTO.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
if oldName != updateDTO.Name {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Workspace updated and renamed from '%s' to '%s'", oldName, updateDTO.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
} else {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Workspace updated: %s", updateDTO.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
return existingWorkspace, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package cloudflare_turnstile
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CloudflareTurnstileService verifies Cloudflare Turnstile CAPTCHA tokens
// against the siteverify endpoint. When no secret key is configured the
// service is disabled and every verification succeeds.
type CloudflareTurnstileService struct {
	secretKey string
	siteKey   string
}

// cloudflareTurnstileResponse mirrors the JSON body returned by the
// Cloudflare siteverify endpoint.
type cloudflareTurnstileResponse struct {
	Success     bool      `json:"success"`
	ChallengeTS time.Time `json:"challenge_ts"`
	Hostname    string    `json:"hostname"`
	ErrorCodes  []string  `json:"error-codes"`
}

const cloudflareTurnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

// cloudflareTurnstileHTTPClient bounds the siteverify request so a slow or
// unreachable Cloudflare endpoint cannot hang sign-in/sign-up indefinitely
// (http.PostForm uses http.DefaultClient, which has no timeout).
var cloudflareTurnstileHTTPClient = &http.Client{Timeout: 10 * time.Second}

// IsEnabled reports whether Turnstile verification is configured
// (i.e. a secret key was provided).
func (s *CloudflareTurnstileService) IsEnabled() bool {
	return s.secretKey != ""
}

// VerifyToken validates a client-supplied Turnstile token with Cloudflare.
// remoteIP is forwarded to Cloudflare as an optional hint. It returns
// (true, nil) when verification passes or when the service is disabled, and
// (false, err) otherwise.
func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool, error) {
	if !s.IsEnabled() {
		return true, nil
	}

	if token == "" {
		return false, errors.New("cloudflare Turnstile token is required")
	}

	formData := url.Values{}
	formData.Set("secret", s.secretKey)
	formData.Set("response", token)
	formData.Set("remoteip", remoteIP)

	resp, err := cloudflareTurnstileHTTPClient.PostForm(cloudflareTurnstileVerifyURL, formData)
	if err != nil {
		return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	// A non-200 answer means the verification service itself failed; do not
	// try to parse the body as a success/failure verdict.
	if resp.StatusCode != http.StatusOK {
		return false, fmt.Errorf("cloudflare Turnstile returned status %d", resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return false, fmt.Errorf("failed to read Cloudflare Turnstile response: %w", err)
	}

	var turnstileResp cloudflareTurnstileResponse
	if err := json.Unmarshal(body, &turnstileResp); err != nil {
		return false, fmt.Errorf("failed to parse Cloudflare Turnstile response: %w", err)
	}

	if !turnstileResp.Success {
		return false, fmt.Errorf(
			"cloudflare Turnstile verification failed: %v",
			turnstileResp.ErrorCodes,
		)
	}

	return true, nil
}
|
||||
14
backend/internal/util/cloudflare_turnstile/di.go
Normal file
14
backend/internal/util/cloudflare_turnstile/di.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package cloudflare_turnstile
|
||||
|
||||
import (
|
||||
"databasus-backend/internal/config"
|
||||
)
|
||||
|
||||
var cloudflareTurnstileService = &CloudflareTurnstileService{
|
||||
config.GetEnv().CloudflareTurnstileSecretKey,
|
||||
config.GetEnv().CloudflareTurnstileSiteKey,
|
||||
}
|
||||
|
||||
func GetCloudflareTurnstileService() *CloudflareTurnstileService {
|
||||
return cloudflareTurnstileService
|
||||
}
|
||||
48
backend/internal/util/files/sanitizer.go
Normal file
48
backend/internal/util/files/sanitizer.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package files_utils
|
||||
|
||||
// SanitizeFilename replaces characters that are invalid or problematic in filenames
// across different operating systems (Windows, Linux, macOS) and storage systems
// (local filesystem, S3, FTP, SFTP, NAS, rclone, Azure Blob, Google Drive).
//
// The following characters are replaced:
//   - Space (' ') -> underscore ('_')
//   - '/', '\', ':', '*', '?', '"', '<', '>', '|' -> hyphen ('-')
//
// All other runes (including Unicode letters, dots, and parentheses) are kept
// unchanged, so names stay recognizable while being safe on:
//   - Windows (strict filename rules)
//   - Unix/Linux/macOS (forward slashes are path separators)
//   - All cloud storage providers (S3, Azure Blob, Google Drive)
//   - Network storage (FTP, SFTP, NAS, rclone)
func SanitizeFilename(name string) string {
	// A switch avoids allocating and rebuilding a replacement map on every
	// call, which the previous implementation did.
	result := make([]rune, 0, len(name))
	for _, char := range name {
		switch char {
		case ' ':
			result = append(result, '_')
		case '/', '\\', ':', '*', '?', '"', '<', '>', '|':
			result = append(result, '-')
		default:
			result = append(result, char)
		}
	}

	return string(result)
}
|
||||
217
backend/internal/util/files/sanitizer_test.go
Normal file
217
backend/internal/util/files/sanitizer_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package files_utils
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_SanitizeFilename_ReplacesSpecialCharacters(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "replaces spaces with underscores",
|
||||
input: "my database name",
|
||||
expected: "my_database_name",
|
||||
},
|
||||
{
|
||||
name: "replaces forward slashes",
|
||||
input: "db/prod/main",
|
||||
expected: "db-prod-main",
|
||||
},
|
||||
{
|
||||
name: "replaces backslashes",
|
||||
input: "db\\prod\\main",
|
||||
expected: "db-prod-main",
|
||||
},
|
||||
{
|
||||
name: "replaces colons",
|
||||
input: "db:production:main",
|
||||
expected: "db-production-main",
|
||||
},
|
||||
{
|
||||
name: "replaces asterisks",
|
||||
input: "db*wildcard",
|
||||
expected: "db-wildcard",
|
||||
},
|
||||
{
|
||||
name: "replaces question marks",
|
||||
input: "db?query",
|
||||
expected: "db-query",
|
||||
},
|
||||
{
|
||||
name: "replaces double quotes",
|
||||
input: "db\"quoted\"name",
|
||||
expected: "db-quoted-name",
|
||||
},
|
||||
{
|
||||
name: "replaces less than signs",
|
||||
input: "db<redirect",
|
||||
expected: "db-redirect",
|
||||
},
|
||||
{
|
||||
name: "replaces greater than signs",
|
||||
input: "db>output",
|
||||
expected: "db-output",
|
||||
},
|
||||
{
|
||||
name: "replaces pipes",
|
||||
input: "db|pipe",
|
||||
expected: "db-pipe",
|
||||
},
|
||||
{
|
||||
name: "replaces multiple different special characters",
|
||||
input: "my db:/backup\\file*2024?",
|
||||
expected: "my_db--backup-file-2024-",
|
||||
},
|
||||
{
|
||||
name: "handles all special characters at once",
|
||||
input: " /\\:*?\"<>|",
|
||||
expected: "_---------",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SanitizeFilename_HandlesEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "empty string returns empty string",
|
||||
input: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "string with no special characters remains unchanged",
|
||||
input: "simple_database_name",
|
||||
expected: "simple_database_name",
|
||||
},
|
||||
{
|
||||
name: "string with hyphens and underscores remains unchanged",
|
||||
input: "my-database_name-123",
|
||||
expected: "my-database_name-123",
|
||||
},
|
||||
{
|
||||
name: "preserves alphanumeric characters",
|
||||
input: "Database123ABC",
|
||||
expected: "Database123ABC",
|
||||
},
|
||||
{
|
||||
name: "preserves dots and parentheses",
|
||||
input: "db.production.(v2)",
|
||||
expected: "db.production.(v2)",
|
||||
},
|
||||
{
|
||||
name: "handles unicode characters",
|
||||
input: "база_данных_テスト",
|
||||
expected: "база_данных_テスト",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SanitizeFilename_WindowsReservedNames(t *testing.T) {
|
||||
// Windows reserved names are case-insensitive: CON, PRN, AUX, NUL, COM1-COM9, LPT1-LPT9
|
||||
// Our function doesn't handle these specifically because:
|
||||
// 1. Database names in our system are typically lowercase
|
||||
// 2. These are combined with timestamps and UUIDs in filenames (e.g., "CON-20240102-150405-uuid")
|
||||
// 3. The timestamp and UUID suffix make the final filename safe on Windows
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "CON remains as CON (will be safe with timestamp suffix)",
|
||||
input: "CON",
|
||||
expected: "CON",
|
||||
},
|
||||
{
|
||||
name: "PRN remains as PRN (will be safe with timestamp suffix)",
|
||||
input: "PRN",
|
||||
expected: "PRN",
|
||||
},
|
||||
{
|
||||
name: "COM1 remains as COM1 (will be safe with timestamp suffix)",
|
||||
input: "COM1",
|
||||
expected: "COM1",
|
||||
},
|
||||
{
|
||||
name: "handles database name with reserved name as part",
|
||||
input: "my:CON/database",
|
||||
expected: "my-CON-database",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SanitizeFilename_RealWorldExamples(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "production database with environment",
|
||||
input: "prod:main/db",
|
||||
expected: "prod-main-db",
|
||||
},
|
||||
{
|
||||
name: "database with spaces and version",
|
||||
input: "My App Database v2.0",
|
||||
expected: "My_App_Database_v2.0",
|
||||
},
|
||||
{
|
||||
name: "database with special query chars",
|
||||
input: "analytics?region=us*",
|
||||
expected: "analytics-region=us-",
|
||||
},
|
||||
{
|
||||
name: "windows-style path in database name",
|
||||
input: "C:\\databases\\prod",
|
||||
expected: "C--databases-prod",
|
||||
},
|
||||
{
|
||||
name: "unix-style path in database name",
|
||||
input: "/var/lib/postgres/main",
|
||||
expected: "-var-lib-postgres-main",
|
||||
},
|
||||
{
|
||||
name: "database name with quotes",
|
||||
input: "\"production\" database",
|
||||
expected: "-production-_database",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SanitizeFilename(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
-- Adds a mandatory file_name column to backups so stored artifacts can carry
-- a human-readable (sanitized) name instead of being addressed by UUID only.

-- +goose Up
-- +goose StatementBegin
ALTER TABLE backups ADD COLUMN file_name TEXT;
-- +goose StatementEnd

-- Backfill existing rows with their UUID so the NOT NULL constraint below
-- can be applied without failing on historical backups.
-- +goose StatementBegin
UPDATE backups SET file_name = id::TEXT WHERE file_name IS NULL;
-- +goose StatementEnd

-- +goose StatementBegin
ALTER TABLE backups ALTER COLUMN file_name SET NOT NULL;
-- +goose StatementEnd

-- +goose Down
-- +goose StatementBegin
ALTER TABLE backups DROP COLUMN file_name;
-- +goose StatementEnd
|
||||
@@ -2,4 +2,5 @@ MODE=development
|
||||
VITE_GITHUB_CLIENT_ID=
|
||||
VITE_GOOGLE_CLIENT_ID=
|
||||
VITE_IS_EMAIL_CONFIGURED=false
|
||||
VITE_IS_CLOUD=false
VITE_CLOUDFLARE_TURNSTILE_SITE_KEY=
|
||||
@@ -3,6 +3,7 @@ interface RuntimeConfig {
|
||||
GITHUB_CLIENT_ID?: string;
|
||||
GOOGLE_CLIENT_ID?: string;
|
||||
IS_EMAIL_CONFIGURED?: string;
|
||||
CLOUDFLARE_TURNSTILE_SITE_KEY?: string;
|
||||
}
|
||||
|
||||
declare global {
|
||||
@@ -39,6 +40,11 @@ export const IS_EMAIL_CONFIGURED =
|
||||
window.__RUNTIME_CONFIG__?.IS_EMAIL_CONFIGURED === 'true' ||
|
||||
import.meta.env.VITE_IS_EMAIL_CONFIGURED === 'true';
|
||||
|
||||
export const CLOUDFLARE_TURNSTILE_SITE_KEY =
|
||||
window.__RUNTIME_CONFIG__?.CLOUDFLARE_TURNSTILE_SITE_KEY ||
|
||||
import.meta.env.VITE_CLOUDFLARE_TURNSTILE_SITE_KEY ||
|
||||
'';
|
||||
|
||||
export function getOAuthRedirectUri(): string {
|
||||
return `${window.location.origin}/auth/callback`;
|
||||
}
|
||||
|
||||
@@ -31,10 +31,18 @@ const notifyAuthListeners = () => {
|
||||
};
|
||||
|
||||
export const userApi = {
|
||||
async signUp(signUpRequest: SignUpRequest) {
|
||||
async signUp(signUpRequest: SignUpRequest): Promise<SignInResponse> {
|
||||
const requestOptions: RequestOptions = new RequestOptions();
|
||||
requestOptions.setBody(JSON.stringify(signUpRequest));
|
||||
return apiHelper.fetchPostRaw(`${getApplicationServer()}/api/v1/users/signup`, requestOptions);
|
||||
|
||||
return apiHelper
|
||||
.fetchPostJson(`${getApplicationServer()}/api/v1/users/signup`, requestOptions)
|
||||
.then((response: unknown): SignInResponse => {
|
||||
const typedResponse = response as SignInResponse;
|
||||
saveAuthorizedData(typedResponse.token, typedResponse.userId);
|
||||
notifyAuthListeners();
|
||||
return typedResponse;
|
||||
});
|
||||
},
|
||||
|
||||
async signIn(signInRequest: SignInRequest): Promise<SignInResponse> {
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
export interface SendResetPasswordCodeRequest {
|
||||
email: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
export interface SignInRequest {
|
||||
email: string;
|
||||
password: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -2,4 +2,5 @@ export interface SignUpRequest {
|
||||
email: string;
|
||||
password: string;
|
||||
name: string;
|
||||
cloudflareTurnstileToken?: string;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
|
||||
interface RequestResetPasswordComponentProps {
|
||||
onSwitchToSignIn?: () => void;
|
||||
@@ -20,6 +23,8 @@ export function RequestResetPasswordComponent({
|
||||
const [error, setError] = useState('');
|
||||
const [successMessage, setSuccessMessage] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateEmail = (): boolean => {
|
||||
if (!email) {
|
||||
setEmailError(true);
|
||||
@@ -42,7 +47,10 @@ export function RequestResetPasswordComponent({
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const response = await userApi.sendResetPasswordCode({ email });
|
||||
const response = await userApi.sendResetPasswordCode({
|
||||
email,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
setSuccessMessage(response.message);
|
||||
|
||||
// After successful code send, switch to reset password form
|
||||
@@ -53,6 +61,7 @@ export function RequestResetPasswordComponent({
|
||||
}, 2000);
|
||||
} catch (e) {
|
||||
setError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
@@ -84,6 +93,8 @@ export function RequestResetPasswordComponent({
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID, IS_EMAIL_CONFIGURED } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
@@ -29,6 +32,8 @@ export function SignInComponent({
|
||||
|
||||
const [signInError, setSignInError] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateFieldsForSignIn = (): boolean => {
|
||||
if (!email) {
|
||||
setEmailError(true);
|
||||
@@ -59,9 +64,11 @@ export function SignInComponent({
|
||||
await userApi.signIn({
|
||||
email,
|
||||
password,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
} catch (e) {
|
||||
setSignInError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
|
||||
setLoading(false);
|
||||
@@ -119,6 +126,8 @@ export function SignInComponent({
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -2,10 +2,13 @@ import { EyeInvisibleOutlined, EyeTwoTone } from '@ant-design/icons';
|
||||
import { App, Button, Input } from 'antd';
|
||||
import { type JSX, useState } from 'react';
|
||||
|
||||
import { useCloudflareTurnstile } from '../../../shared/hooks/useCloudflareTurnstile';
|
||||
|
||||
import { GITHUB_CLIENT_ID, GOOGLE_CLIENT_ID } from '../../../constants';
|
||||
import { userApi } from '../../../entity/users';
|
||||
import { StringUtils } from '../../../shared/lib';
|
||||
import { FormValidator } from '../../../shared/lib/FormValidator';
|
||||
import { CloudflareTurnstileWidget } from '../../../shared/ui/CloudflareTurnstileWidget';
|
||||
import { GithubOAuthComponent } from './oauth/GithubOAuthComponent';
|
||||
import { GoogleOAuthComponent } from './oauth/GoogleOAuthComponent';
|
||||
|
||||
@@ -31,6 +34,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
|
||||
const [signUpError, setSignUpError] = useState('');
|
||||
|
||||
const { token, containerRef, resetCloudflareTurnstile } = useCloudflareTurnstile();
|
||||
|
||||
const validateFieldsForSignUp = (): boolean => {
|
||||
if (!name || name.trim() === '') {
|
||||
setNameError(true);
|
||||
@@ -85,10 +90,11 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
email,
|
||||
password,
|
||||
name,
|
||||
cloudflareTurnstileToken: token,
|
||||
});
|
||||
await userApi.signIn({ email, password });
|
||||
} catch (e) {
|
||||
setSignUpError(StringUtils.capitalizeFirstLetter((e as Error).message));
|
||||
resetCloudflareTurnstile();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -173,6 +179,8 @@ export function SignUpComponent({ onSwitchToSignIn }: SignUpComponentProps): JSX
|
||||
|
||||
<div className="mt-3" />
|
||||
|
||||
<CloudflareTurnstileWidget containerRef={containerRef} />
|
||||
|
||||
<Button
|
||||
disabled={isLoading}
|
||||
loading={isLoading}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import dayjs from 'dayjs';
|
||||
import relativeTime from 'dayjs/plugin/relativeTime';
|
||||
import utc from 'dayjs/plugin/utc';
|
||||
import { StrictMode } from 'react';
|
||||
import { createRoot } from 'react-dom/client';
|
||||
|
||||
import './index.css';
|
||||
@@ -11,8 +10,4 @@ import App from './App.tsx';
|
||||
dayjs.extend(utc);
|
||||
dayjs.extend(relativeTime);
|
||||
|
||||
createRoot(document.getElementById('root')!).render(
|
||||
<StrictMode>
|
||||
<App />
|
||||
</StrictMode>,
|
||||
);
|
||||
createRoot(document.getElementById('root')!).render(<App />);
|
||||
|
||||
116
frontend/src/shared/hooks/useCloudflareTurnstile.ts
Normal file
116
frontend/src/shared/hooks/useCloudflareTurnstile.ts
Normal file
@@ -0,0 +1,116 @@
|
||||
import { useEffect, useRef, useState } from 'react';
|
||||
|
||||
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
|
||||
|
||||
declare global {
|
||||
interface Window {
|
||||
turnstile?: {
|
||||
render: (
|
||||
container: string | HTMLElement,
|
||||
options: {
|
||||
sitekey: string;
|
||||
callback: (token: string) => void;
|
||||
'error-callback'?: () => void;
|
||||
'expired-callback'?: () => void;
|
||||
theme?: 'light' | 'dark' | 'auto';
|
||||
size?: 'normal' | 'compact' | 'flexible';
|
||||
appearance?: 'always' | 'execute' | 'interaction-only';
|
||||
},
|
||||
) => string;
|
||||
reset: (widgetId: string) => void;
|
||||
remove: (widgetId: string) => void;
|
||||
getResponse: (widgetId: string) => string | undefined;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
interface UseCloudflareTurnstileReturn {
|
||||
containerRef: React.RefObject<HTMLDivElement | null>;
|
||||
token: string | undefined;
|
||||
resetCloudflareTurnstile: () => void;
|
||||
}
|
||||
|
||||
const loadCloudflareTurnstileScript = (): Promise<void> => {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
|
||||
return Promise.resolve();
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
if (document.querySelector('script[src*="turnstile"]')) {
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
const script = document.createElement('script');
|
||||
script.src = 'https://challenges.cloudflare.com/turnstile/v0/api.js?render=explicit';
|
||||
script.async = true;
|
||||
script.defer = true;
|
||||
script.onload = () => resolve();
|
||||
script.onerror = () => reject(new Error('Failed to load Cloudflare Turnstile'));
|
||||
document.head.appendChild(script);
|
||||
});
|
||||
};
|
||||
|
||||
export function useCloudflareTurnstile(): UseCloudflareTurnstileReturn {
|
||||
const [token, setToken] = useState<string | undefined>(undefined);
|
||||
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||
const widgetIdRef = useRef<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY || !containerRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadCloudflareTurnstileScript()
|
||||
.then(() => {
|
||||
if (!window.turnstile || !containerRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const widgetId = window.turnstile.render(containerRef.current, {
|
||||
sitekey: CLOUDFLARE_TURNSTILE_SITE_KEY,
|
||||
callback: (receivedToken: string) => {
|
||||
setToken(receivedToken);
|
||||
},
|
||||
'error-callback': () => {
|
||||
setToken(undefined);
|
||||
},
|
||||
'expired-callback': () => {
|
||||
setToken(undefined);
|
||||
},
|
||||
theme: 'auto',
|
||||
size: 'normal',
|
||||
appearance: 'execute',
|
||||
});
|
||||
|
||||
widgetIdRef.current = widgetId;
|
||||
} catch (error) {
|
||||
console.error('Failed to render Cloudflare Turnstile widget:', error);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('Failed to load Cloudflare Turnstile:', error);
|
||||
});
|
||||
|
||||
return () => {
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.remove(widgetIdRef.current);
|
||||
widgetIdRef.current = null;
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
const resetCloudflareTurnstile = () => {
|
||||
if (widgetIdRef.current && window.turnstile) {
|
||||
window.turnstile.reset(widgetIdRef.current);
|
||||
setToken(undefined);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
containerRef,
|
||||
token,
|
||||
resetCloudflareTurnstile,
|
||||
};
|
||||
}
|
||||
17
frontend/src/shared/ui/CloudflareTurnstileWidget.tsx
Normal file
17
frontend/src/shared/ui/CloudflareTurnstileWidget.tsx
Normal file
@@ -0,0 +1,17 @@
|
||||
import { type JSX } from 'react';
|
||||
|
||||
import { CLOUDFLARE_TURNSTILE_SITE_KEY } from '../../constants';
|
||||
|
||||
/**
 * Props for `CloudflareTurnstileWidget`.
 */
interface CloudflareTurnstileWidgetProps {
  /** Ref supplied by `useCloudflareTurnstile`; the widget renders into this node. */
  containerRef: React.RefObject<HTMLDivElement | null>;
}
|
||||
|
||||
export function CloudflareTurnstileWidget({
|
||||
containerRef,
|
||||
}: CloudflareTurnstileWidgetProps): JSX.Element | null {
|
||||
if (!CLOUDFLARE_TURNSTILE_SITE_KEY) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <div ref={containerRef} className="mb-3" />;
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
export { CloudflareTurnstileWidget } from './CloudflareTurnstileWidget';
|
||||
export { ConfirmationComponent } from './ConfirmationComponent';
|
||||
export { StarButtonComponent } from './StarButtonComponent';
|
||||
export { ThemeToggleComponent } from './ThemeToggleComponent';
|
||||
|
||||
Reference in New Issue
Block a user