diff --git a/agent/.golangci.yml b/agent/.golangci.yml index 9e54a7d..5b27d77 100644 --- a/agent/.golangci.yml +++ b/agent/.golangci.yml @@ -9,13 +9,11 @@ linters: default: standard enable: - funcorder - - gosec - bodyclose - errorlint - gocritic - unconvert - misspell - - dupword - errname - noctx - modernize diff --git a/agent/internal/features/upgrade/upgrader.go b/agent/internal/features/upgrade/upgrader.go index 6dd28c4..f44decf 100644 --- a/agent/internal/features/upgrade/upgrader.go +++ b/agent/internal/features/upgrade/upgrader.go @@ -1,6 +1,7 @@ package upgrade import ( + "context" "encoding/json" "fmt" "io" @@ -74,9 +75,17 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger } func fetchServerVersion(host string, log Logger) (string, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Get(host + "/api/v1/system/version") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil) + if err != nil { + return "", err + } + + resp, err := client.Do(req) if err != nil { log.Warn("Could not reach server for update check, continuing", "error", err) return "", err @@ -104,7 +113,15 @@ func fetchServerVersion(host string, log Logger) (string, error) { func downloadBinary(host, destPath string) error { url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH) - resp, err := http.Get(url) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + + resp, err := http.DefaultClient.Do(req) if err != nil { return err } @@ -126,7 +143,7 @@ func downloadBinary(host, destPath string) error { } func verifyBinary(binaryPath, expectedVersion string) error { - cmd := exec.Command(binaryPath, "version") + cmd := exec.CommandContext(context.Background(), binaryPath, "version") output, err := cmd.Output() if err != nil { diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 18e8b00..d9de713 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -9,13 +9,11 @@ linters: default: standard enable: - funcorder - - gosec - bodyclose - errorlint - gocritic - unconvert - misspell - - dupword - errname - noctx - modernize diff --git a/backend/cmd/main.go b/backend/cmd/main.go index 611eaad..7b0b189 100644 --- a/backend/cmd/main.go +++ b/backend/cmd/main.go @@ -353,7 +353,9 @@ func generateSwaggerDocs(log *slog.Logger) { return } - cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger") + cmd := exec.CommandContext( + context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger", + ) output, err := cmd.CombinedOutput() if err != nil { @@ -367,7 +369,7 @@ func generateSwaggerDocs(log *slog.Logger) { func runMigrations(log *slog.Logger) { log.Info("Running database migrations...") - cmd := exec.Command("goose", "-dir", "./migrations", "up") + cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up") cmd.Env = append( os.Environ(), "GOOSE_DRIVER=postgres", diff --git a/backend/internal/features/audit_logs/repository.go b/backend/internal/features/audit_logs/repository.go index d6264c8..9d1d38e 100644 --- a/backend/internal/features/audit_logs/repository.go +++ b/backend/internal/features/audit_logs/repository.go @@ -38,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal( LEFT JOIN users u ON al.user_id = u.id LEFT JOIN workspaces w ON al.workspace_id = w.id` - args := []interface{}{} + args := []any{} if beforeDate != nil { sql += " WHERE al.created_at < ?" @@ -75,7 +75,7 @@ func (r *AuditLogRepository) GetByUser( LEFT JOIN workspaces w ON al.workspace_id = w.id WHERE al.user_id = ?` - args := []interface{}{userID} + args := []any{userID} if beforeDate != nil { sql += " AND al.created_at < ?" @@ -112,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace( LEFT JOIN workspaces w ON al.workspace_id = w.id WHERE al.workspace_id = ?` - args := []interface{}{workspaceID} + args := []any{workspaceID} if beforeDate != nil { sql += " AND al.created_at < ?" diff --git a/backend/internal/features/backups/backups/backuping/cleaner.go b/backend/internal/features/backups/backups/backuping/cleaner.go index 0eff38a..44c3c1b 100644 --- a/backend/internal/features/backups/backups/backuping/cleaner.go +++ b/backend/internal/features/backups/backups/backuping/cleaner.go @@ -446,22 +446,24 @@ func buildGFSKeepSet( } dailyCutoff := rawDailyCutoff - if weeks > 0 { + switch { + case weeks > 0: dailyCutoff = laterOf(dailyCutoff, weeklyCutoff) - } else if months > 0 { + case months > 0: dailyCutoff = laterOf(dailyCutoff, monthlyCutoff) - } else if years > 0 { + case years > 0: dailyCutoff = laterOf(dailyCutoff, yearlyCutoff) } hourlyCutoff := rawHourlyCutoff - if days > 0 { + switch { + case days > 0: hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff) - } else if weeks > 0 { + case weeks > 0: hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff) - } else if months > 0 { + case months > 0: hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff) - } else if years > 0 { + case years > 0: hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff) } diff --git a/backend/internal/features/backups/backups/common/interfaces.go b/backend/internal/features/backups/backups/common/interfaces.go index c9a0852..6e57475 100644 --- a/backend/internal/features/backups/backups/common/interfaces.go +++ b/backend/internal/features/backups/backups/common/interfaces.go @@ -7,6 +7,10 @@ type CountingWriter struct { BytesWritten int64 } +func NewCountingWriter(writer io.Writer) *CountingWriter { + return &CountingWriter{Writer: writer} +} + func (cw *CountingWriter) Write(p []byte) (n int, err error) { n, err = cw.Writer.Write(p) cw.BytesWritten += int64(n) @@ -16,7 +20,3 @@ func (cw *CountingWriter) Write(p []byte) (n int, err error) { func (cw *CountingWriter) GetBytesWritten() int64 { return cw.BytesWritten } - -func NewCountingWriter(writer io.Writer) *CountingWriter { - return &CountingWriter{Writer: writer} -} diff --git a/backend/internal/features/backups/backups/controllers/controller.go b/backend/internal/features/backups/backups/controllers/controller.go index 2e7d9d2..20aa7a8 100644 --- a/backend/internal/features/backups/backups/controllers/controller.go +++ b/backend/internal/features/backups/backups/controllers/controller.go @@ -2,6 +2,7 @@ package backups_controllers import ( "context" + "errors" "fmt" "io" "net/http" @@ -198,7 +199,7 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) { response, err := c.backupService.GenerateDownloadToken(user, id) if err != nil { - if err == backups_download.ErrDownloadAlreadyInProgress { + if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) { ctx.JSON( http.StatusConflict, gin.H{ @@ -249,7 +250,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) { downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token) if err != nil { - if err == backups_download.ErrDownloadAlreadyInProgress { + if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) { ctx.JSON( http.StatusConflict, gin.H{ diff --git a/backend/internal/features/backups/backups/core/repository.go b/backend/internal/features/backups/backups/core/repository.go index da506d2..c422f80 100644 --- a/backend/internal/features/backups/backups/core/repository.go +++ b/backend/internal/features/backups/backups/core/repository.go @@ -88,7 +88,7 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup, Where("database_id = ?", databaseID). Order("created_at DESC"). First(&backup).Error; err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } diff --git a/backend/internal/features/backups/backups/download/rate_limiter.go b/backend/internal/features/backups/backups/download/rate_limiter.go index df5e73f..54209ea 100644 --- a/backend/internal/features/backups/backups/download/rate_limiter.go +++ b/backend/internal/features/backups/backups/download/rate_limiter.go @@ -66,9 +66,7 @@ func (rl *RateLimiter) Wait(bytes int64) { tokensNeeded := float64(bytes) - rl.availableTokens waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond - if waitTime < time.Millisecond { - waitTime = time.Millisecond - } + waitTime = max(waitTime, time.Millisecond) rl.mu.Unlock() time.Sleep(waitTime) diff --git a/backend/internal/features/backups/backups/download/repository.go b/backend/internal/features/backups/backups/download/repository.go index c1e5707..34e0631 100644 --- a/backend/internal/features/backups/backups/download/repository.go +++ b/backend/internal/features/backups/backups/download/repository.go @@ -3,6 +3,7 @@ package backups_download import ( "crypto/rand" "encoding/base64" + "errors" "time" "github.com/google/uuid" @@ -30,7 +31,7 @@ func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, err Where("token = ?", token). First(&downloadToken).Error if err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } return nil, err diff --git a/backend/internal/features/backups/backups/encryption/decrypting_reader.go b/backend/internal/features/backups/backups/encryption/decrypting_reader.go index 1b9b766..ea61935 100644 --- a/backend/internal/features/backups/backups/encryption/decrypting_reader.go +++ b/backend/internal/features/backups/backups/encryption/decrypting_reader.go @@ -4,6 +4,7 @@ import ( "crypto/aes" "crypto/cipher" "encoding/binary" + "errors" "fmt" "io" @@ -69,7 +70,7 @@ func NewDecryptionReader( func (r *DecryptionReader) Read(p []byte) (n int, err error) { for len(r.buffer) < len(p) && !r.eof { if err := r.readAndDecryptChunk(); err != nil { - if err == io.EOF { + if errors.Is(err, io.EOF) { r.eof = true break } diff --git a/backend/internal/features/backups/backups/services/postgres_wal_service.go b/backend/internal/features/backups/backups/services/postgres_wal_service.go index 7e3bba1..452733e 100644 --- a/backend/internal/features/backups/backups/services/postgres_wal_service.go +++ b/backend/internal/features/backups/backups/services/postgres_wal_service.go @@ -225,6 +225,80 @@ func (s *PostgreWalBackupService) DownloadBackupFile( return s.backupService.GetBackupReader(backupID) } +func (s *PostgreWalBackupService) GetNextFullBackupTime( + database *databases.Database, +) (*backups_dto.GetNextFullBackupTimeResponse, error) { + if err := s.validateWalBackupType(database); err != nil { + return nil, err + } + + backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) + if err != nil { + return nil, fmt.Errorf("failed to get backup config: %w", err) + } + + if backupConfig.BackupInterval == nil { + return nil, fmt.Errorf("no backup interval configured for database %s", database.ID) + } + + lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID( + database.ID, + ) + if err != nil { + return nil, fmt.Errorf("failed to query last full backup: %w", err) + } + + var lastBackupTime *time.Time + if lastFullBackup != nil { + lastBackupTime = &lastFullBackup.CreatedAt + } + + now := time.Now().UTC() + nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime) + + return &backups_dto.GetNextFullBackupTimeResponse{ + NextFullBackupTime: nextTime, + }, nil +} + +// ReportError creates a FAILED backup record with the agent's error message. +func (s *PostgreWalBackupService) ReportError( + database *databases.Database, + errorMsg string, +) error { + if err := s.validateWalBackupType(database); err != nil { + return err + } + + backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) + if err != nil { + return fmt.Errorf("failed to get backup config: %w", err) + } + + if backupConfig.Storage == nil { + return fmt.Errorf("no storage configured for database %s", database.ID) + } + + now := time.Now().UTC() + backup := &backups_core.Backup{ + ID: uuid.New(), + DatabaseID: database.ID, + StorageID: backupConfig.Storage.ID, + Status: backups_core.BackupStatusFailed, + FailMessage: &errorMsg, + Encryption: backupConfig.Encryption, + CreatedAt: now, + } + + backup.GenerateFilename(database.Name) + + if err := s.backupRepository.Save(backup); err != nil { + return fmt.Errorf("failed to save error backup record: %w", err) + } + + return nil +} + func (s *PostgreWalBackupService) validateWalChain( databaseID uuid.UUID, incomingSegment string, @@ -432,80 +506,6 @@ func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg } } -func (s *PostgreWalBackupService) GetNextFullBackupTime( - database *databases.Database, -) (*backups_dto.GetNextFullBackupTimeResponse, error) { - if err := s.validateWalBackupType(database); err != nil { - return nil, err - } - - backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) - if err != nil { - return nil, fmt.Errorf("failed to get backup config: %w", err) - } - - if backupConfig.BackupInterval == nil { - return nil, fmt.Errorf("no backup interval configured for database %s", database.ID) - } - - lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID( - database.ID, - ) - if err != nil { - return nil, fmt.Errorf("failed to query last full backup: %w", err) - } - - var lastBackupTime *time.Time - if lastFullBackup != nil { - lastBackupTime = &lastFullBackup.CreatedAt - } - - now := time.Now().UTC() - nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime) - - return &backups_dto.GetNextFullBackupTimeResponse{ - NextFullBackupTime: nextTime, - }, nil -} - -// ReportError creates a FAILED backup record with the agent's error message. -func (s *PostgreWalBackupService) ReportError( - database *databases.Database, - errorMsg string, -) error { - if err := s.validateWalBackupType(database); err != nil { - return err - } - - backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID) - if err != nil { - return fmt.Errorf("failed to get backup config: %w", err) - } - - if backupConfig.Storage == nil { - return fmt.Errorf("no storage configured for database %s", database.ID) - } - - now := time.Now().UTC() - backup := &backups_core.Backup{ - ID: uuid.New(), - DatabaseID: database.ID, - StorageID: backupConfig.Storage.ID, - Status: backups_core.BackupStatusFailed, - FailMessage: &errorMsg, - Encryption: backupConfig.Encryption, - CreatedAt: now, - } - - backup.GenerateFilename(database.Name) - - if err := s.backupRepository.Save(backup); err != nil { - return fmt.Errorf("failed to save error backup record: %w", err) - } - - return nil -} - func (s *PostgreWalBackupService) resolveFullBackup( databaseID uuid.UUID, backupID *uuid.UUID, diff --git a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go index d36ddd1..79ae639 100644 --- a/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mariadb/create_backup_uc.go @@ -548,8 +548,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpErrorMessage( stderrStr, ) - exitErr, ok := waitErr.(*exec.ExitError) - if !ok { + var exitErr *exec.ExitError + if !errors.As(waitErr, &exitErr) { return errors.New(errorMsg) } diff --git a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go index 083fac9..0163f40 100644 --- a/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/mysql/create_backup_uc.go @@ -565,8 +565,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpErrorMessage( stderrStr, ) - exitErr, ok := waitErr.(*exec.ExitError) - if !ok { + var exitErr *exec.ExitError + if !errors.As(waitErr, &exitErr) { return errors.New(errorMsg) } diff --git a/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go b/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go index df0400e..3d4fafd 100644 --- a/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go +++ b/backend/internal/features/backups/backups/usecases/postgresql/create_backup_uc.go @@ -595,8 +595,8 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage( stderrStr := string(stderrOutput) errorMsg := fmt.Sprintf("%s failed: %v – stderr: %s", filepath.Base(pgBin), waitErr, stderrStr) - exitErr, ok := waitErr.(*exec.ExitError) - if !ok { + var exitErr *exec.ExitError + if !errors.As(waitErr, &exitErr) { return errors.New(errorMsg) } diff --git a/backend/internal/features/backups/config/service.go b/backend/internal/features/backups/config/service.go index 61a0493..ff86b6c 100644 --- a/backend/internal/features/backups/config/service.go +++ b/backend/internal/features/backups/config/service.go @@ -214,39 +214,6 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e return s.initializeDefaultConfig(databaseID) } -func (s *BackupConfigService) initializeDefaultConfig( - databaseID uuid.UUID, -) error { - plan, err := s.databasePlanService.GetDatabasePlan(databaseID) - if err != nil { - return err - } - - timeOfDay := "04:00" - - _, err = s.backupConfigRepository.Save(&BackupConfig{ - DatabaseID: databaseID, - IsBackupsEnabled: false, - RetentionPolicyType: RetentionPolicyTypeTimePeriod, - RetentionTimePeriod: plan.MaxStoragePeriod, - MaxBackupSizeMB: plan.MaxBackupSizeMB, - MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB, - BackupInterval: &intervals.Interval{ - Interval: intervals.IntervalDaily, - TimeOfDay: &timeOfDay, - }, - SendNotificationsOn: []BackupNotificationType{ - NotificationBackupFailed, - NotificationBackupSuccess, - }, - IsRetryIfFailed: true, - MaxFailedTriesCount: 3, - Encryption: BackupEncryptionNone, - }) - - return err -} - func (s *BackupConfigService) TransferDatabaseToWorkspace( user *users_models.User, databaseID uuid.UUID, @@ -290,7 +257,8 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace( s.transferNotifiers(user, database, request.TargetWorkspaceID) } - if request.IsTransferWithStorage { + switch { + case request.IsTransferWithStorage: if backupConfig.StorageID == nil { return ErrDatabaseHasNoStorage } @@ -315,7 +283,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace( if err != nil { return err } - } else if request.TargetStorageID != nil { + case request.TargetStorageID != nil: targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID) if err != nil { return err @@ -332,7 +300,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace( if err != nil { return err } - } else { + default: return ErrTargetStorageNotSpecified } @@ -351,6 +319,39 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace( return nil } +func (s *BackupConfigService) initializeDefaultConfig( + databaseID uuid.UUID, +) error { + plan, err := s.databasePlanService.GetDatabasePlan(databaseID) + if err != nil { + return err + } + + timeOfDay := "04:00" + + _, err = s.backupConfigRepository.Save(&BackupConfig{ + DatabaseID: databaseID, + IsBackupsEnabled: false, + RetentionPolicyType: RetentionPolicyTypeTimePeriod, + RetentionTimePeriod: plan.MaxStoragePeriod, + MaxBackupSizeMB: plan.MaxBackupSizeMB, + MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB, + BackupInterval: &intervals.Interval{ + Interval: intervals.IntervalDaily, + TimeOfDay: &timeOfDay, + }, + SendNotificationsOn: []BackupNotificationType{ + NotificationBackupFailed, + NotificationBackupSuccess, + }, + IsRetryIfFailed: true, + MaxFailedTriesCount: 3, + Encryption: BackupEncryptionNone, + }) + + return err +} + func (s *BackupConfigService) transferNotifiers( user *users_models.User, database *databases.Database, diff --git a/backend/internal/features/databases/databases/mariadb/model.go b/backend/internal/features/databases/databases/mariadb/model.go index 54baa4e..aa84f99 100644 --- a/backend/internal/features/databases/databases/mariadb/model.go +++ b/backend/internal/features/databases/databases/mariadb/model.go @@ -391,7 +391,7 @@ func (m *MariadbDatabase) HasPrivilege(priv string) bool { } func HasPrivilege(privileges, priv string) bool { - for _, p := range strings.Split(privileges, ",") { + for p := range strings.SplitSeq(privileges, ",") { if strings.TrimSpace(p) == priv { return true } diff --git a/backend/internal/features/databases/databases/mongodb/model.go b/backend/internal/features/databases/databases/mongodb/model.go index c167ec2..60a44ec 100644 --- a/backend/internal/features/databases/databases/mongodb/model.go +++ b/backend/internal/features/databases/databases/mongodb/model.go @@ -451,6 +451,48 @@ func (m *MongodbDatabase) CreateReadOnlyUser( return "", "", errors.New("failed to generate unique username after 3 attempts") } +// BuildMongodumpURI builds a URI suitable for mongodump (without database in path) +func (m *MongodbDatabase) BuildMongodumpURI(password string) string { + authDB := m.AuthDatabase + if authDB == "" { + authDB = "admin" + } + + extraParams := "" + if m.IsHttps { + extraParams += "&tls=true&tlsInsecure=true" + } + if m.IsDirectConnection { + extraParams += "&directConnection=true" + } + + if m.IsSrv { + return fmt.Sprintf( + "mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s", + url.QueryEscape(m.Username), + url.QueryEscape(password), + m.Host, + authDB, + extraParams, + ) + } + + port := 27017 + if m.Port != nil { + port = *m.Port + } + + return fmt.Sprintf( + "mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s", + url.QueryEscape(m.Username), + url.QueryEscape(password), + m.Host, + port, + authDB, + extraParams, + ) +} + // buildConnectionURI builds a MongoDB connection URI func (m *MongodbDatabase) buildConnectionURI(password string) string { authDB := m.AuthDatabase @@ -495,48 +537,6 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string { ) } -// BuildMongodumpURI builds a URI suitable for mongodump (without database in path) -func (m *MongodbDatabase) BuildMongodumpURI(password string) string { - authDB := m.AuthDatabase - if authDB == "" { - authDB = "admin" - } - - extraParams := "" - if m.IsHttps { - extraParams += "&tls=true&tlsInsecure=true" - } - if m.IsDirectConnection { - extraParams += "&directConnection=true" - } - - if m.IsSrv { - return fmt.Sprintf( - "mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s", - url.QueryEscape(m.Username), - url.QueryEscape(password), - m.Host, - authDB, - extraParams, - ) - } - - port := 27017 - if m.Port != nil { - port = *m.Port - } - - return fmt.Sprintf( - "mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s", - url.QueryEscape(m.Username), - url.QueryEscape(password), - m.Host, - port, - authDB, - extraParams, - ) -} - // detectMongodbVersion gets MongoDB server version from buildInfo command func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.MongodbVersion, error) { adminDB := client.Database("admin") diff --git a/backend/internal/features/databases/databases/postgresql/model.go b/backend/internal/features/databases/databases/postgresql/model.go index 01414e2..bf83b4d 100644 --- a/backend/internal/features/databases/databases/postgresql/model.go +++ b/backend/internal/features/databases/databases/postgresql/model.go @@ -1153,8 +1153,8 @@ func isSupabaseConnection(host, username string) bool { } func extractSupabaseProjectID(username string) string { - if idx := strings.Index(username, "."); idx != -1 { - return username[idx+1:] + if _, after, found := strings.Cut(username, "."); found { + return after } return "" } diff --git a/backend/internal/features/email/email.go b/backend/internal/features/email/email.go index 8ce11b3..22d3f30 100644 --- a/backend/internal/features/email/email.go +++ b/backend/internal/features/email/email.go @@ -1,6 +1,7 @@ package email import ( + "context" "crypto/tls" "fmt" "log/slog" @@ -114,7 +115,7 @@ func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error tlsConfig := &tls.Config{ServerName: s.smtpHost} dialer := &net.Dialer{Timeout: DefaultTimeout} - conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) + conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr) if err != nil { return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err) } diff --git a/backend/internal/features/notifiers/models/discord/model.go b/backend/internal/features/notifiers/models/discord/model.go index 55f8757..d864801 100644 --- a/backend/internal/features/notifiers/models/discord/model.go +++ b/backend/internal/features/notifiers/models/discord/model.go @@ -2,6 +2,7 @@ package discord_notifier import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -47,7 +48,7 @@ func (d *DiscordNotifier) Send( fullMessage = fmt.Sprintf("%s\n\n%s", heading, message) } - payload := map[string]interface{}{ + payload := map[string]any{ "content": fullMessage, } @@ -56,7 +57,7 @@ func (d *DiscordNotifier) Send( return fmt.Errorf("failed to marshal Discord payload: %w", err) } - req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload)) + req, err := http.NewRequestWithContext(context.Background(), "POST", webhookURL, bytes.NewReader(jsonPayload)) if err != nil { return fmt.Errorf("failed to create request: %w", err) } diff --git a/backend/internal/features/notifiers/models/email_notifier/model.go b/backend/internal/features/notifiers/models/email_notifier/model.go index a7f93a3..0e879ce 100644 --- a/backend/internal/features/notifiers/models/email_notifier/model.go +++ b/backend/internal/features/notifiers/models/email_notifier/model.go @@ -1,6 +1,7 @@ package email_notifier import ( + "context" "crypto/tls" "errors" "fmt" @@ -207,7 +208,7 @@ func (e *EmailNotifier) createImplicitTLSClient() (*smtp.Client, func(), error) } dialer := &net.Dialer{Timeout: DefaultTimeout} - conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) + conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr) if err != nil { return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err) } diff --git a/backend/internal/features/notifiers/models/slack/model.go b/backend/internal/features/notifiers/models/slack/model.go index 3b7d336..d42f3ea 100644 --- a/backend/internal/features/notifiers/models/slack/model.go +++ b/backend/internal/features/notifiers/models/slack/model.go @@ -2,6 +2,7 @@ package slack_notifier import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -86,7 +87,8 @@ func (s *SlackNotifier) Send( for { attempts++ - req, err := http.NewRequest( + req, err := http.NewRequestWithContext( + context.Background(), "POST", "https://slack.com/api/chat.postMessage", bytes.NewReader(payload), @@ -136,7 +138,7 @@ func (s *SlackNotifier) Send( if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil { raw, _ := io.ReadAll(resp.Body) - return fmt.Errorf("decode response: %v – raw: %s", err, raw) + return fmt.Errorf("decode response: %w – raw: %s", err, raw) } if !respBody.OK { diff --git a/backend/internal/features/notifiers/models/teams/model.go b/backend/internal/features/notifiers/models/teams/model.go index 34fedf7..c75933e 100644 --- a/backend/internal/features/notifiers/models/teams/model.go +++ b/backend/internal/features/notifiers/models/teams/model.go @@ -2,6 +2,7 @@ package teams_notifier import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -88,7 +89,7 @@ func (n *TeamsNotifier) Send( } body, _ := json.Marshal(p) - req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body)) if err != nil { return err } diff --git a/backend/internal/features/notifiers/models/telegram/model.go b/backend/internal/features/notifiers/models/telegram/model.go index adf9f29..713ea81 100644 --- a/backend/internal/features/notifiers/models/telegram/model.go +++ b/backend/internal/features/notifiers/models/telegram/model.go @@ -1,6 +1,7 @@ package telegram_notifier import ( + "context" "errors" "fmt" "io" @@ -65,7 +66,7 @@ func (t *TelegramNotifier) Send( data.Set("message_thread_id", strconv.FormatInt(*t.ThreadID, 10)) } - req, err := http.NewRequest("POST", apiURL, strings.NewReader(data.Encode())) + req, err := http.NewRequestWithContext(context.Background(), "POST", apiURL, strings.NewReader(data.Encode())) if err != nil { return fmt.Errorf("failed to create request: %w", err) } diff --git a/backend/internal/features/notifiers/models/webhook/model.go b/backend/internal/features/notifiers/models/webhook/model.go index 6dbbe87..d82092b 100644 --- a/backend/internal/features/notifiers/models/webhook/model.go +++ b/backend/internal/features/notifiers/models/webhook/model.go @@ -2,6 +2,7 @@ package webhook_notifier import ( "bytes" + "context" "encoding/json" "errors" "fmt" @@ -146,7 +147,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s url.QueryEscape(message), ) - req, err := http.NewRequest(http.MethodGet, reqURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil) if err != nil { return fmt.Errorf("failed to create GET request: %w", err) } @@ -180,7 +181,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s func (t *WebhookNotifier) sendPOST(webhookURL, heading, message string, logger *slog.Logger) error { body := t.buildRequestBody(heading, message) - req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body)) + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body)) if err != nil { return fmt.Errorf("failed to create POST request: %w", err) } diff --git a/backend/internal/features/notifiers/service.go b/backend/internal/features/notifiers/service.go index 027a51e..c01a50e 100644 --- a/backend/internal/features/notifiers/service.go +++ b/backend/internal/features/notifiers/service.go @@ -343,10 +343,8 @@ func (s *NotifierService) TransferNotifierToWorkspace( return ErrNotifierHasOtherAttachedDatabasesCannotTransfer } } - } else { - if len(attachedDatabasesIDs) > 0 { - return ErrNotifierHasAttachedDatabasesCannotTransfer - } + } else if len(attachedDatabasesIDs) > 0 { + return ErrNotifierHasAttachedDatabasesCannotTransfer } sourceWorkspaceID := existingNotifier.WorkspaceID diff --git a/backend/internal/features/restores/restoring/restorer_test.go b/backend/internal/features/restores/restoring/restorer_test.go index ce315c4..f7881fb 100644 --- a/backend/internal/features/restores/restoring/restorer_test.go +++ b/backend/internal/features/restores/restoring/restorer_test.go @@ -162,3 +162,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) { cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String()) assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts") } + +func stringPtr(s string) *string { + return &s +} diff --git a/backend/internal/features/restores/restoring/scheduler.go b/backend/internal/features/restores/restoring/scheduler.go index 0f36766..80c55a8 100644 --- a/backend/internal/features/restores/restoring/scheduler.go +++ b/backend/internal/features/restores/restoring/scheduler.go @@ -101,27 +101,6 @@ func (s *RestoresScheduler) IsSchedulerRunning() bool { return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold)) } -func (s *RestoresScheduler) failRestoresInProgress() error { - restoresInProgress, err := s.restoreRepository.FindByStatus( - restores_core.RestoreStatusInProgress, - ) - if err != nil { - return err - } - - for _, restore := range restoresInProgress { - failMessage := "Restore failed due to application restart" - restore.FailMessage = &failMessage - restore.Status = restores_core.RestoreStatusFailed - - if err := s.restoreRepository.Save(restore); err != nil { - return err - } - } - - return nil -} - func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error { // If dbCache not provided, try to fetch from DB (for backward compatibility/testing) if dbCache == nil { @@ -326,6 +305,27 @@ func (s *RestoresScheduler) onRestoreCompleted(nodeID, restoreID uuid.UUID) { } } +func (s *RestoresScheduler) failRestoresInProgress() error { + restoresInProgress, err := s.restoreRepository.FindByStatus( + restores_core.RestoreStatusInProgress, + ) + if err != nil { + return err + } + + for _, restore := range restoresInProgress { + failMessage := "Restore failed due to application restart" + restore.FailMessage = &failMessage + restore.Status = restores_core.RestoreStatusFailed + + if err := s.restoreRepository.Save(restore); err != nil { + return err + } + } + + return nil +} + func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error { nodes, err := s.restoreNodesRegistry.GetAvailableNodes() if err != nil { diff --git a/backend/internal/features/restores/restoring/testing.go b/backend/internal/features/restores/restoring/testing.go index 2db3ce3..30002d7 100644 --- a/backend/internal/features/restores/restoring/testing.go +++ b/backend/internal/features/restores/restoring/testing.go @@ -324,7 +324,7 @@ func CreateTestRestore( Port: 5432, Username: "test", Password: "test", - Database: stringPtr("testdb"), + Database: func() *string { s := "testdb"; return &s }(), Version: "16", }, } @@ -336,7 +336,3 @@ func CreateTestRestore( return restore } - -func stringPtr(s string) *string { - return &s -} diff --git a/backend/internal/features/restores/service.go b/backend/internal/features/restores/service.go index 1807a5d..e9ef880 100644 --- a/backend/internal/features/restores/service.go +++ b/backend/internal/features/restores/service.go @@ -199,6 +199,54 @@ func (s *RestoreService) RestoreBackupWithAuth( return nil } +func (s *RestoreService) CancelRestore( + user *users_models.User, + restoreID uuid.UUID, +) error { + restore, err := s.restoreRepository.FindByID(restoreID) + if err != nil { + return err + } + + backup, err := s.backupService.GetBackup(restore.BackupID) + if err != nil { + return err + } + + database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID) + if err != nil { + return err + } + + if database.WorkspaceID == nil { + return errors.New("cannot cancel restore for database without workspace") + } + + canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user) + if err != nil { + return err + } + if !canManage { + return errors.New("insufficient permissions to cancel restore for this database") + } + + if restore.Status != restores_core.RestoreStatusInProgress { + return errors.New("restore is not in progress") + } + + if err := s.taskCancelManager.CancelTask(restoreID); err != nil { + return err + } + + s.auditLogService.WriteAuditLog( + fmt.Sprintf("Restore cancelled for database: %s", database.Name), + &user.ID, + database.WorkspaceID, + ) + + return nil +} + func (s *RestoreService) validateVersionCompatibility( backupDatabase *databases.Database, requestDTO restores_core.RestoreBackupRequest, @@ -295,6 +343,7 @@ func (s *RestoreService) validateVersionCompatibility( `For example, you can restore MongoDB 6.0 backup to MongoDB 6.0, 7.0 or higher. But cannot restore to 5.0`) } } + return nil } @@ -368,51 +417,3 @@ func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error return nil } - -func (s *RestoreService) CancelRestore( - user *users_models.User, - restoreID uuid.UUID, -) error { - restore, err := s.restoreRepository.FindByID(restoreID) - if err != nil { - return err - } - - backup, err := s.backupService.GetBackup(restore.BackupID) - if err != nil { - return err - } - - database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID) - if err != nil { - return err - } - - if database.WorkspaceID == nil { - return errors.New("cannot cancel restore for database without workspace") - } - - canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user) - if err != nil { - return err - } - if !canManage { - return errors.New("insufficient permissions to cancel restore for this database") - } - - if restore.Status != restores_core.RestoreStatusInProgress { - return errors.New("restore is not in progress") - } - - if err := s.taskCancelManager.CancelTask(restoreID); err != nil { - return err - } - - s.auditLogService.WriteAuditLog( - fmt.Sprintf("Restore cancelled for database: %s", database.Name), - &user.ID, - database.WorkspaceID, - ) - - return nil -} diff --git a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go index 6ce9172..b074d3e 100644 --- a/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go +++ b/backend/internal/features/restores/usecases/postgresql/restore_backup_uc.go @@ -304,11 +304,12 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin( _, copyErr := io.Copy(stdinPipe, backupReader) // Close stdin pipe to signal EOF to pg_restore - critical for proper termination closeErr := stdinPipe.Close() - if copyErr != nil { + switch { + case copyErr != nil: copyErrCh <- fmt.Errorf("copy to stdin: %w", copyErr) - } else if closeErr != nil { + case closeErr != nil: copyErrCh <- fmt.Errorf("close stdin: %w", closeErr) - } else { + default: copyErrCh <- nil } }() @@ -764,10 +765,12 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError( ) // Check for specific PostgreSQL error patterns - if exitErr, ok := waitErr.(*exec.ExitError); ok { + var exitErr *exec.ExitError + if errors.As(waitErr, &exitErr) { exitCode := exitErr.ExitCode() - if exitCode == 1 && strings.TrimSpace(stderrStr) == "" { + switch { + case exitCode == 1 && strings.TrimSpace(stderrStr) == "": errorMsg = fmt.Sprintf( "%s failed with exit status 1 but provided no error details. "+ "This often indicates: "+ @@ -782,45 +785,46 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError( pgBin, strings.Join(args, " "), ) - } else if exitCode == -1073741819 { // 0xC0000005 in decimal + case exitCode == -1073741819: // 0xC0000005 in decimal errorMsg = fmt.Sprintf( "%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s", filepath.Base(pgBin), stderrStr, ) - } else if exitCode == 1 || exitCode == 2 { + case exitCode == 1 || exitCode == 2: // Check for common connection and authentication issues - if containsIgnoreCase(stderrStr, "pg_hba.conf") { + switch { + case containsIgnoreCase(stderrStr, "pg_hba.conf"): errorMsg = fmt.Sprintf( "PostgreSQL connection rejected by server configuration (pg_hba.conf). stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") { + case containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth"): errorMsg = fmt.Sprintf( "PostgreSQL authentication failed - no password supplied. stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") { + case containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection"): errorMsg = fmt.Sprintf( "PostgreSQL SSL connection failed. stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") { + case containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused"): errorMsg = fmt.Sprintf( "PostgreSQL connection refused. Check if the server is running and accessible. stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") { + case containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password"): errorMsg = fmt.Sprintf( "PostgreSQL authentication failed. Check username and password. stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "timeout") { + case containsIgnoreCase(stderrStr, "timeout"): errorMsg = fmt.Sprintf( "PostgreSQL connection timeout. stderr: %s", stderrStr, ) - } else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") { + case containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist"): backupDbName := "unknown" if database.Postgresql != nil && database.Postgresql.Database != nil { backupDbName = *database.Postgresql.Database diff --git a/backend/internal/features/storages/models/azure_blob/model.go b/backend/internal/features/storages/models/azure_blob/model.go index 32f5d42..d938119 100644 --- a/backend/internal/features/storages/models/azure_blob/model.go +++ b/backend/internal/features/storages/models/azure_blob/model.go @@ -109,7 +109,7 @@ func (s *AzureBlobStorage) SaveFile( return fmt.Errorf("read error: %w", readErr) } - blockID := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%06d", blockNumber))) + blockID := base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "%06d", blockNumber)) _, err := blockBlobClient.StageBlock( ctx, @@ -337,7 +337,7 @@ func (s *AzureBlobStorage) buildBlobName(fileName string) string { prefix = strings.TrimPrefix(prefix, "/") if !strings.HasSuffix(prefix, "/") { - prefix = prefix + "/" + prefix += "/" } return prefix + fileName diff --git a/backend/internal/features/storages/models/google_drive/model.go b/backend/internal/features/storages/models/google_drive/model.go index 74f2424..5e08776 100644 --- a/backend/internal/features/storages/models/google_drive/model.go +++ b/backend/internal/features/storages/models/google_drive/model.go @@ -167,7 +167,7 @@ func (s *GoogleDriveStorage) GetFile( return err } - resp, err := driveService.Files.Get(fileIDGoogle).Download() + resp, err := driveService.Files.Get(fileIDGoogle).Download() //nolint:bodyclose if err != nil { return fmt.Errorf("failed to download file from Google Drive: %w", err) } @@ -358,7 +358,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth( if strings.Contains(refreshErr.Error(), "invalid_grant") || strings.Contains(refreshErr.Error(), "refresh token") { return fmt.Errorf( - "google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %v", + "google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %w", err, refreshErr, ) @@ -488,7 +488,7 @@ func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) e // maskSensitiveData masks sensitive information in token JSON for logging func maskSensitiveData(tokenJSON string) string { // Replace sensitive values with masked versions - var data map[string]interface{} + var data map[string]any if err := json.Unmarshal([]byte(tokenJSON), &data); err != nil { return "invalid JSON" } diff --git a/backend/internal/features/storages/models/nas/model.go b/backend/internal/features/storages/models/nas/model.go index 2ccec4c..cafab8d 100644 --- a/backend/internal/features/storages/models/nas/model.go +++ b/backend/internal/features/storages/models/nas/model.go @@ -356,7 +356,7 @@ func (n *NASStorage) createConnectionWithContext(ctx context.Context) (net.Conn, ServerName: n.Host, InsecureSkipVerify: false, } - conn, err := tls.DialWithDialer(dialer, "tcp", address, tlsConfig) + conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(ctx, "tcp", address) if err != nil { return nil, fmt.Errorf("failed to create SSL connection to %s: %w", address, err) } diff --git a/backend/internal/features/storages/models/s3/model.go b/backend/internal/features/storages/models/s3/model.go index 5b1e9d5..ddbf911 100644 --- a/backend/internal/features/storages/models/s3/model.go +++ b/backend/internal/features/storages/models/s3/model.go @@ -207,7 +207,7 @@ func (s *S3Storage) GetFile( // Check if the file actually exists by reading the first byte buf := make([]byte, 1) _, readErr := object.Read(buf) - if readErr != nil && readErr != io.EOF { + if readErr != nil && !errors.Is(readErr, io.EOF) { _ = object.Close() return nil, fmt.Errorf("file does not exist in S3: %w", readErr) } @@ -372,7 +372,7 @@ func (s *S3Storage) buildObjectKey(fileName string) string { prefix = strings.TrimPrefix(prefix, "/") if !strings.HasSuffix(prefix, "/") { - prefix = prefix + "/" + prefix += "/" } return prefix + fileName @@ -428,11 +428,11 @@ func (s *S3Storage) getClientParams( endpoint = s.S3Endpoint useSSL = true - if strings.HasPrefix(endpoint, "http://") { + if after, ok := strings.CutPrefix(endpoint, "http://"); ok { useSSL = false - endpoint = strings.TrimPrefix(endpoint, "http://") - } else if strings.HasPrefix(endpoint, "https://") { - endpoint = strings.TrimPrefix(endpoint, "https://") + endpoint = after + } else if after, ok := strings.CutPrefix(endpoint, "https://"); ok { + endpoint = after } if endpoint == "" { diff --git a/backend/internal/features/storages/models/sftp/model.go b/backend/internal/features/storages/models/sftp/model.go index b3f56db..869ada9 100644 --- a/backend/internal/features/storages/models/sftp/model.go +++ b/backend/internal/features/storages/models/sftp/model.go @@ -298,12 +298,7 @@ func (s *SFTPStorage) connectWithContext( authMethods = append(authMethods, ssh.PublicKeys(signer)) } - var hostKeyCallback ssh.HostKeyCallback - if s.SkipHostKeyVerify { - hostKeyCallback = ssh.InsecureIgnoreHostKey() - } else { - hostKeyCallback = ssh.InsecureIgnoreHostKey() - } + hostKeyCallback := ssh.InsecureIgnoreHostKey() config := &ssh.ClientConfig{ User: s.Username, diff --git a/backend/internal/features/storages/service.go b/backend/internal/features/storages/service.go index 46fb3ab..eb567b8 100644 --- a/backend/internal/features/storages/service.go +++ b/backend/internal/features/storages/service.go @@ -364,10 +364,8 @@ func (s *StorageService) TransferStorageToWorkspace( return ErrStorageHasOtherAttachedDatabasesCannotTransfer } } - } else { - if len(attachedDatabasesIDs) > 0 { - return ErrStorageHasAttachedDatabasesCannotTransfer - } + } else if len(attachedDatabasesIDs) > 0 { + return ErrStorageHasAttachedDatabasesCannotTransfer } sourceWorkspaceID := existingStorage.WorkspaceID diff --git a/backend/internal/features/test_once_protection.go b/backend/internal/features/test_once_protection.go index d88a42e..c43864d 100644 --- a/backend/internal/features/test_once_protection.go +++ b/backend/internal/features/test_once_protection.go @@ -1,7 +1,6 @@ package features import ( - "context" "fmt" "sync" "testing" @@ -60,11 +59,9 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) { // Call SetupDependencies concurrently from 10 goroutines for range 10 { - wg.Add(1) - go func() { - defer wg.Done() + wg.Go(func() { audit_logs.SetupDependencies() - }() + }) } wg.Wait() @@ -73,8 +70,7 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) { // Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() // Create a test background service backgroundService := audit_logs.GetAuditLogBackgroundService() @@ -107,8 +103,7 @@ func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) { // Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + ctx := t.Context() scheduler := backuping.GetBackupsScheduler() diff --git a/backend/internal/features/users/repositories/user_repository.go b/backend/internal/features/users/repositories/user_repository.go index 10830ca..f4a9faf 100644 --- a/backend/internal/features/users/repositories/user_repository.go +++ b/backend/internal/features/users/repositories/user_repository.go @@ -1,6 +1,7 @@ package users_repositories import ( + "errors" "fmt" "time" @@ -31,7 +32,7 @@ func (r *UserRepository) GetUserByEmail(email string) (*users_models.User, error var user users_models.User if err := storage.GetDb().Where("email = ?", email).First(&user).Error; err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } @@ -179,7 +180,7 @@ func (r *UserRepository) UpdateUserInfo(userID uuid.UUID, name, email *string) e func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models.User, error) { var user users_models.User err := storage.GetDb().Where("github_oauth_id = ?", githubID).First(&user).Error - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } if err != nil { @@ -191,7 +192,7 @@ func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models. func (r *UserRepository) GetUserByGoogleOAuthID(googleID string) (*users_models.User, error) { var user users_models.User err := storage.GetDb().Where("google_oauth_id = ?", googleID).First(&user).Error - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { return nil, nil } if err != nil { diff --git a/backend/internal/features/users/repositories/users_settings_repository.go b/backend/internal/features/users/repositories/users_settings_repository.go index e15fc10..2df61fb 100644 --- a/backend/internal/features/users/repositories/users_settings_repository.go +++ b/backend/internal/features/users/repositories/users_settings_repository.go @@ -1,6 +1,8 @@ package users_repositories import ( + "errors" + "github.com/google/uuid" "gorm.io/gorm" @@ -14,7 +16,7 @@ func (r *UsersSettingsRepository) GetSettings() (*user_models.UsersSettings, err var settings user_models.UsersSettings if err := storage.GetDb().First(&settings).Error; err != nil { - if err == gorm.ErrRecordNotFound { + if errors.Is(err, gorm.ErrRecordNotFound) { // Create default settings if none exist defaultSettings := &user_models.UsersSettings{ ID: uuid.New(), diff --git a/backend/internal/features/users/services/user_services.go b/backend/internal/features/users/services/user_services.go index f283745..bd76f52 100644 --- a/backend/internal/features/users/services/user_services.go +++ b/backend/internal/features/users/services/user_services.go @@ -685,7 +685,11 @@ func (s *UserService) handleGitHubOAuthWithEndpoint( } client := oauthConfig.Client(context.Background(), token) - resp, err := client.Get(userAPIURL) + githubReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create user info request: %w", err) + } + resp, err := client.Do(githubReq) if err != nil { return nil, fmt.Errorf("failed to get user info: %w", err) } @@ -754,7 +758,11 @@ func (s *UserService) handleGoogleOAuthWithEndpoint( } client := oauthConfig.Client(context.Background(), token) - resp, err := client.Get(userAPIURL) + googleReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create user info request: %w", err) + } + resp, err := client.Do(googleReq) if err != nil { return nil, fmt.Errorf("failed to get user info: %w", err) } @@ -950,7 +958,11 @@ func (s *UserService) fetchGitHubPrimaryEmail( emailsURL = baseURL + "/user/emails" } - resp, err := client.Get(emailsURL) + emailsReq, err := http.NewRequestWithContext(context.Background(), "GET", emailsURL, nil) + if err != nil { + return "", fmt.Errorf("failed to create user emails request: %w", err) + } + resp, err := client.Do(emailsReq) if err != nil { return "", fmt.Errorf("failed to get user emails: %w", err) } diff --git a/backend/internal/features/users/testing/mocks.go b/backend/internal/features/users/testing/mocks.go index 57fc27d..93b85b3 100644 --- a/backend/internal/features/users/testing/mocks.go +++ b/backend/internal/features/users/testing/mocks.go @@ -13,6 +13,13 @@ type EmailCall struct { Body string } +func NewMockEmailSender() *MockEmailSender { + return &MockEmailSender{ + SentEmails: []EmailCall{}, + ShouldFail: false, + } +} + func (m *MockEmailSender) SendEmail(to, subject, body string) error { m.SentEmails = append(m.SentEmails, EmailCall{ To: to, @@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error { } return nil } - -func NewMockEmailSender() *MockEmailSender { - return &MockEmailSender{ - SentEmails: []EmailCall{}, - ShouldFail: false, - } -} diff --git a/backend/internal/features/workspaces/testing/mocks.go b/backend/internal/features/workspaces/testing/mocks.go index b469bc1..3c9477f 100644 --- a/backend/internal/features/workspaces/testing/mocks.go +++ b/backend/internal/features/workspaces/testing/mocks.go @@ -13,6 +13,13 @@ type EmailCall struct { Body string } +func NewMockEmailSender() *MockEmailSender { + return &MockEmailSender{ + SendEmailCalls: []EmailCall{}, + ShouldFail: false, + } +} + func (m *MockEmailSender) SendEmail(to, subject, body string) error { m.SendEmailCalls = append(m.SendEmailCalls, EmailCall{ To: to, @@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error { } return nil } - -func NewMockEmailSender() *MockEmailSender { - return &MockEmailSender{ - SendEmailCalls: []EmailCall{}, - ShouldFail: false, - } -} diff --git a/backend/internal/features/workspaces/testing/testing.go b/backend/internal/features/workspaces/testing/testing.go index cafaebd..169f669 100644 --- a/backend/internal/features/workspaces/testing/testing.go +++ b/backend/internal/features/workspaces/testing/testing.go @@ -2,6 +2,7 @@ package workspaces_testing import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -433,7 +434,7 @@ func MakeAPIRequest( requestBody = bytes.NewBuffer(nil) } - req, err := http.NewRequest(method, url, requestBody) + req, err := http.NewRequestWithContext(context.Background(), method, url, requestBody) if err != nil { panic(err) } diff --git a/backend/internal/util/cloudflare_turnstile/cloudflare_turnstile_service.go b/backend/internal/util/cloudflare_turnstile/cloudflare_turnstile_service.go index b95a474..0a5e46a 100644 --- a/backend/internal/util/cloudflare_turnstile/cloudflare_turnstile_service.go +++ b/backend/internal/util/cloudflare_turnstile/cloudflare_turnstile_service.go @@ -1,12 +1,14 @@ package cloudflare_turnstile import ( + "context" "encoding/json" "errors" "fmt" "io" "net/http" "net/url" + "strings" "time" ) @@ -42,7 +44,15 @@ func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool, formData.Set("response", token) formData.Set("remoteip", remoteIP) - resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData) + req, err := http.NewRequestWithContext( + context.Background(), "POST", cloudflareTurnstileVerifyURL, strings.NewReader(formData.Encode()), + ) + if err != nil { + return false, fmt.Errorf("failed to create Cloudflare Turnstile request: %w", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := http.DefaultClient.Do(req) if err != nil { return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err) } diff --git a/backend/internal/util/logger/multi_handler.go b/backend/internal/util/logger/multi_handler.go index 8d48a5f..3274c72 100644 --- a/backend/internal/util/logger/multi_handler.go +++ b/backend/internal/util/logger/multi_handler.go @@ -32,7 +32,7 @@ func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error { // Send to VictoriaLogs if configured if h.victoriaLogsWriter != nil { - attrs := make(map[string]interface{}) + attrs := make(map[string]any) record.Attrs(func(a slog.Attr) bool { attrs[a.Key] = a.Value.Any() return true diff --git a/backend/internal/util/logger/victorialogs_writer.go b/backend/internal/util/logger/victorialogs_writer.go index 48190a8..42beb18 100644 --- a/backend/internal/util/logger/victorialogs_writer.go +++ b/backend/internal/util/logger/victorialogs_writer.go @@ -59,7 +59,7 @@ func NewVictoriaLogsWriter(url, username, password string) *VictoriaLogsWriter { return writer } -func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]interface{}) { +func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]any) { entry := logEntry{ Time: time.Now().UTC().Format(time.RFC3339Nano), Message: message, @@ -76,6 +76,27 @@ func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]inter } } +func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) { + w.once.Do(func() { + // Stop accepting new logs + w.cancel() + + // Wait for workers to finish with timeout + done := make(chan struct{}) + go func() { + w.wg.Wait() + close(done) + }() + + select { + case <-done: + w.logger.Info("VictoriaLogs writer shutdown gracefully") + case <-time.After(timeout): + w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost") + } + }) +} + func (w *VictoriaLogsWriter) worker() { defer w.wg.Done() @@ -180,24 +201,3 @@ func (w *VictoriaLogsWriter) flushBatch(batch []logEntry) { w.sendBatch(batch) } } - -func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) { - w.once.Do(func() { - // Stop accepting new logs - w.cancel() - - // Wait for workers to finish with timeout - done := make(chan struct{}) - go func() { - w.wg.Wait() - close(done) - }() - - select { - case <-done: - w.logger.Info("VictoriaLogs writer shutdown gracefully") - case <-time.After(timeout): - w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost") - } - }) -} diff --git a/backend/internal/util/testing/requests.go b/backend/internal/util/testing/requests.go index c3d69d6..55c474f 100644 --- a/backend/internal/util/testing/requests.go +++ b/backend/internal/util/testing/requests.go @@ -2,6 +2,7 @@ package testing import ( "bytes" + "context" "encoding/json" "net/http" "net/http/httptest" @@ -14,7 +15,7 @@ import ( type RequestOptions struct { Method string URL string - Body interface{} + Body any Headers map[string]string AuthToken string ExpectedStatus int @@ -40,7 +41,7 @@ func MakeGetRequestAndUnmarshal( router *gin.Engine, url, authToken string, expectedStatus int, - responseStruct interface{}, + responseStruct any, ) *TestResponse { return makeAuthenticatedRequestAndUnmarshal( t, @@ -58,7 +59,7 @@ func MakePostRequest( t *testing.T, router *gin.Engine, url, authToken string, - body interface{}, + body any, expectedStatus int, ) *TestResponse { return makeAuthenticatedRequest(t, router, "POST", url, authToken, body, expectedStatus) @@ -68,9 +69,9 @@ func MakePostRequestAndUnmarshal( t *testing.T, router *gin.Engine, url, authToken string, - body interface{}, + body any, expectedStatus int, - responseStruct interface{}, + responseStruct any, ) *TestResponse { return makeAuthenticatedRequestAndUnmarshal( t, @@ -88,7 +89,7 @@ func MakePutRequest( t *testing.T, router *gin.Engine, url, authToken string, - body interface{}, + body any, expectedStatus int, ) *TestResponse { return makeAuthenticatedRequest(t, router, "PUT", url, authToken, body, expectedStatus) @@ -98,9 +99,9 @@ func MakePutRequestAndUnmarshal( t *testing.T, router *gin.Engine, url, authToken string, - body interface{}, + body any, expectedStatus int, - responseStruct interface{}, + responseStruct any, ) *TestResponse { return makeAuthenticatedRequestAndUnmarshal( t, @@ -134,7 +135,7 @@ func MakeRequest(t *testing.T, router *gin.Engine, options RequestOptions) *Test requestBody = bytes.NewBuffer(nil) } - req, err := http.NewRequest(options.Method, options.URL, requestBody) + req, err := http.NewRequestWithContext(context.Background(), options.Method, options.URL, requestBody) assert.NoError(t, err, "Failed to create HTTP request") if options.Body != nil { @@ -167,7 +168,7 @@ func makeRequestAndUnmarshal( t *testing.T, router *gin.Engine, options RequestOptions, - responseStruct interface{}, + responseStruct any, ) *TestResponse { response := MakeRequest(t, router, options) @@ -183,7 +184,7 @@ func makeAuthenticatedRequest( t *testing.T, router *gin.Engine, method, url, authToken string, - body interface{}, + body any, expectedStatus int, ) *TestResponse { return MakeRequest(t, router, RequestOptions{ @@ -199,9 +200,9 @@ func makeAuthenticatedRequestAndUnmarshal( t *testing.T, router *gin.Engine, method, url, authToken string, - body interface{}, + body any, expectedStatus int, - responseStruct interface{}, + responseStruct any, ) *TestResponse { return makeRequestAndUnmarshal(t, router, RequestOptions{ Method: method,