mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
REFACTOR (linters): Apply linters fixes
This commit is contained in:
@@ -9,13 +9,11 @@ linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- gosec
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- dupword
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -74,9 +75,17 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger
|
||||
}
|
||||
|
||||
func fetchServerVersion(host string, log Logger) (string, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
|
||||
resp, err := client.Get(host + "/api/v1/system/version")
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
log.Warn("Could not reach server for update check, continuing", "error", err)
|
||||
return "", err
|
||||
@@ -104,7 +113,15 @@ func fetchServerVersion(host string, log Logger) (string, error) {
|
||||
func downloadBinary(host, destPath string) error {
|
||||
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
|
||||
|
||||
resp, err := http.Get(url)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -126,7 +143,7 @@ func downloadBinary(host, destPath string) error {
|
||||
}
|
||||
|
||||
func verifyBinary(binaryPath, expectedVersion string) error {
|
||||
cmd := exec.Command(binaryPath, "version")
|
||||
cmd := exec.CommandContext(context.Background(), binaryPath, "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
|
||||
@@ -9,13 +9,11 @@ linters:
|
||||
default: standard
|
||||
enable:
|
||||
- funcorder
|
||||
- gosec
|
||||
- bodyclose
|
||||
- errorlint
|
||||
- gocritic
|
||||
- unconvert
|
||||
- misspell
|
||||
- dupword
|
||||
- errname
|
||||
- noctx
|
||||
- modernize
|
||||
|
||||
@@ -353,7 +353,9 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger")
|
||||
cmd := exec.CommandContext(
|
||||
context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger",
|
||||
)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
@@ -367,7 +369,7 @@ func generateSwaggerDocs(log *slog.Logger) {
|
||||
func runMigrations(log *slog.Logger) {
|
||||
log.Info("Running database migrations...")
|
||||
|
||||
cmd := exec.Command("goose", "-dir", "./migrations", "up")
|
||||
cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up")
|
||||
cmd.Env = append(
|
||||
os.Environ(),
|
||||
"GOOSE_DRIVER=postgres",
|
||||
|
||||
@@ -38,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal(
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
args := []any{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
@@ -75,7 +75,7 @@ func (r *AuditLogRepository) GetByUser(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
args := []any{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
@@ -112,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace(
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
args := []any{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
|
||||
@@ -446,22 +446,24 @@ func buildGFSKeepSet(
|
||||
}
|
||||
|
||||
dailyCutoff := rawDailyCutoff
|
||||
if weeks > 0 {
|
||||
switch {
|
||||
case weeks > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
|
||||
} else if months > 0 {
|
||||
case months > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
case years > 0:
|
||||
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
hourlyCutoff := rawHourlyCutoff
|
||||
if days > 0 {
|
||||
switch {
|
||||
case days > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
|
||||
} else if weeks > 0 {
|
||||
case weeks > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
|
||||
} else if months > 0 {
|
||||
case months > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
|
||||
} else if years > 0 {
|
||||
case years > 0:
|
||||
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,10 @@ type CountingWriter struct {
|
||||
BytesWritten int64
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = cw.Writer.Write(p)
|
||||
cw.BytesWritten += int64(n)
|
||||
@@ -16,7 +20,3 @@ func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
func (cw *CountingWriter) GetBytesWritten() int64 {
|
||||
return cw.BytesWritten
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package backups_controllers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -198,7 +199,7 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
response, err := c.backupService.GenerateDownloadToken(user, id)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
@@ -249,7 +250,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
|
||||
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
|
||||
@@ -88,7 +88,7 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -66,9 +66,7 @@ func (rl *RateLimiter) Wait(bytes int64) {
|
||||
tokensNeeded := float64(bytes) - rl.availableTokens
|
||||
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
|
||||
|
||||
if waitTime < time.Millisecond {
|
||||
waitTime = time.Millisecond
|
||||
}
|
||||
waitTime = max(waitTime, time.Millisecond)
|
||||
|
||||
rl.mu.Unlock()
|
||||
time.Sleep(waitTime)
|
||||
|
||||
@@ -3,6 +3,7 @@ package backups_download
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -30,7 +31,7 @@ func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, err
|
||||
Where("token = ?", token).
|
||||
First(&downloadToken).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
@@ -69,7 +70,7 @@ func NewDecryptionReader(
|
||||
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
|
||||
for len(r.buffer) < len(p) && !r.eof {
|
||||
if err := r.readAndDecryptChunk(); err != nil {
|
||||
if err == io.EOF {
|
||||
if errors.Is(err, io.EOF) {
|
||||
r.eof = true
|
||||
break
|
||||
}
|
||||
|
||||
@@ -225,6 +225,80 @@ func (s *PostgreWalBackupService) DownloadBackupFile(
|
||||
return s.backupService.GetBackupReader(backupID)
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) GetNextFullBackupTime(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.BackupInterval == nil {
|
||||
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last full backup: %w", err)
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastFullBackup != nil {
|
||||
lastBackupTime = &lastFullBackup.CreatedAt
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
|
||||
|
||||
return &backups_dto.GetNextFullBackupTimeResponse{
|
||||
NextFullBackupTime: nextTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReportError creates a FAILED backup record with the agent's error message.
|
||||
func (s *PostgreWalBackupService) ReportError(
|
||||
database *databases.Database,
|
||||
errorMsg string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: backupConfig.Storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &errorMsg,
|
||||
Encryption: backupConfig.Encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to save error backup record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) validateWalChain(
|
||||
databaseID uuid.UUID,
|
||||
incomingSegment string,
|
||||
@@ -432,80 +506,6 @@ func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) GetNextFullBackupTime(
|
||||
database *databases.Database,
|
||||
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.BackupInterval == nil {
|
||||
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query last full backup: %w", err)
|
||||
}
|
||||
|
||||
var lastBackupTime *time.Time
|
||||
if lastFullBackup != nil {
|
||||
lastBackupTime = &lastFullBackup.CreatedAt
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
|
||||
|
||||
return &backups_dto.GetNextFullBackupTimeResponse{
|
||||
NextFullBackupTime: nextTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReportError creates a FAILED backup record with the agent's error message.
|
||||
func (s *PostgreWalBackupService) ReportError(
|
||||
database *databases.Database,
|
||||
errorMsg string,
|
||||
) error {
|
||||
if err := s.validateWalBackupType(database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get backup config: %w", err)
|
||||
}
|
||||
|
||||
if backupConfig.Storage == nil {
|
||||
return fmt.Errorf("no storage configured for database %s", database.ID)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
backup := &backups_core.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: backupConfig.Storage.ID,
|
||||
Status: backups_core.BackupStatusFailed,
|
||||
FailMessage: &errorMsg,
|
||||
Encryption: backupConfig.Encryption,
|
||||
CreatedAt: now,
|
||||
}
|
||||
|
||||
backup.GenerateFilename(database.Name)
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
return fmt.Errorf("failed to save error backup record: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgreWalBackupService) resolveFullBackup(
|
||||
databaseID uuid.UUID,
|
||||
backupID *uuid.UUID,
|
||||
|
||||
@@ -548,8 +548,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpErrorMessage(
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -565,8 +565,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpErrorMessage(
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -595,8 +595,8 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf("%s failed: %v – stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
var exitErr *exec.ExitError
|
||||
if !errors.As(waitErr, &exitErr) {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -214,39 +214,6 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
|
||||
return s.initializeDefaultConfig(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
NotificationBackupSuccess,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
@@ -290,7 +257,8 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
s.transferNotifiers(user, database, request.TargetWorkspaceID)
|
||||
}
|
||||
|
||||
if request.IsTransferWithStorage {
|
||||
switch {
|
||||
case request.IsTransferWithStorage:
|
||||
if backupConfig.StorageID == nil {
|
||||
return ErrDatabaseHasNoStorage
|
||||
}
|
||||
@@ -315,7 +283,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if request.TargetStorageID != nil {
|
||||
case request.TargetStorageID != nil:
|
||||
targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -332,7 +300,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
default:
|
||||
return ErrTargetStorageNotSpecified
|
||||
}
|
||||
|
||||
@@ -351,6 +319,39 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
|
||||
_, err = s.backupConfigRepository.Save(&BackupConfig{
|
||||
DatabaseID: databaseID,
|
||||
IsBackupsEnabled: false,
|
||||
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
|
||||
RetentionTimePeriod: plan.MaxStoragePeriod,
|
||||
MaxBackupSizeMB: plan.MaxBackupSizeMB,
|
||||
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
NotificationBackupSuccess,
|
||||
},
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) transferNotifiers(
|
||||
user *users_models.User,
|
||||
database *databases.Database,
|
||||
|
||||
@@ -391,7 +391,7 @@ func (m *MariadbDatabase) HasPrivilege(priv string) bool {
|
||||
}
|
||||
|
||||
func HasPrivilege(privileges, priv string) bool {
|
||||
for _, p := range strings.Split(privileges, ",") {
|
||||
for p := range strings.SplitSeq(privileges, ",") {
|
||||
if strings.TrimSpace(p) == priv {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -451,6 +451,48 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
// BuildMongodumpURI builds a URI suitable for mongodump (without database in path)
|
||||
func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
authDB := m.AuthDatabase
|
||||
if authDB == "" {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
extraParams := ""
|
||||
if m.IsHttps {
|
||||
extraParams += "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
if m.IsDirectConnection {
|
||||
extraParams += "&directConnection=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
authDB,
|
||||
extraParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
port,
|
||||
authDB,
|
||||
extraParams,
|
||||
)
|
||||
}
|
||||
|
||||
// buildConnectionURI builds a MongoDB connection URI
|
||||
func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
authDB := m.AuthDatabase
|
||||
@@ -495,48 +537,6 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
|
||||
)
|
||||
}
|
||||
|
||||
// BuildMongodumpURI builds a URI suitable for mongodump (without database in path)
|
||||
func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
|
||||
authDB := m.AuthDatabase
|
||||
if authDB == "" {
|
||||
authDB = "admin"
|
||||
}
|
||||
|
||||
extraParams := ""
|
||||
if m.IsHttps {
|
||||
extraParams += "&tls=true&tlsInsecure=true"
|
||||
}
|
||||
if m.IsDirectConnection {
|
||||
extraParams += "&directConnection=true"
|
||||
}
|
||||
|
||||
if m.IsSrv {
|
||||
return fmt.Sprintf(
|
||||
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
authDB,
|
||||
extraParams,
|
||||
)
|
||||
}
|
||||
|
||||
port := 27017
|
||||
if m.Port != nil {
|
||||
port = *m.Port
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
|
||||
url.QueryEscape(m.Username),
|
||||
url.QueryEscape(password),
|
||||
m.Host,
|
||||
port,
|
||||
authDB,
|
||||
extraParams,
|
||||
)
|
||||
}
|
||||
|
||||
// detectMongodbVersion gets MongoDB server version from buildInfo command
|
||||
func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.MongodbVersion, error) {
|
||||
adminDB := client.Database("admin")
|
||||
|
||||
@@ -1153,8 +1153,8 @@ func isSupabaseConnection(host, username string) bool {
|
||||
}
|
||||
|
||||
func extractSupabaseProjectID(username string) string {
|
||||
if idx := strings.Index(username, "."); idx != -1 {
|
||||
return username[idx+1:]
|
||||
if _, after, found := strings.Cut(username, "."); found {
|
||||
return after
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
@@ -114,7 +115,7 @@ func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error
|
||||
tlsConfig := &tls.Config{ServerName: s.smtpHost}
|
||||
dialer := &net.Dialer{Timeout: DefaultTimeout}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
||||
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package discord_notifier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -47,7 +48,7 @@ func (d *DiscordNotifier) Send(
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
}
|
||||
|
||||
payload := map[string]interface{}{
|
||||
payload := map[string]any{
|
||||
"content": fullMessage,
|
||||
}
|
||||
|
||||
@@ -56,7 +57,7 @@ func (d *DiscordNotifier) Send(
|
||||
return fmt.Errorf("failed to marshal Discord payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", webhookURL, bytes.NewReader(jsonPayload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package email_notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -207,7 +208,7 @@ func (e *EmailNotifier) createImplicitTLSClient() (*smtp.Client, func(), error)
|
||||
}
|
||||
dialer := &net.Dialer{Timeout: DefaultTimeout}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
|
||||
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package slack_notifier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -86,7 +87,8 @@ func (s *SlackNotifier) Send(
|
||||
for {
|
||||
attempts++
|
||||
|
||||
req, err := http.NewRequest(
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(),
|
||||
"POST",
|
||||
"https://slack.com/api/chat.postMessage",
|
||||
bytes.NewReader(payload),
|
||||
@@ -136,7 +138,7 @@ func (s *SlackNotifier) Send(
|
||||
|
||||
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
return fmt.Errorf("decode response: %v – raw: %s", err, raw)
|
||||
return fmt.Errorf("decode response: %w – raw: %s", err, raw)
|
||||
}
|
||||
|
||||
if !respBody.OK {
|
||||
|
||||
@@ -2,6 +2,7 @@ package teams_notifier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -88,7 +89,7 @@ func (n *TeamsNotifier) Send(
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(p)
|
||||
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package telegram_notifier
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -65,7 +66,7 @@ func (t *TelegramNotifier) Send(
|
||||
data.Set("message_thread_id", strconv.FormatInt(*t.ThreadID, 10))
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", apiURL, strings.NewReader(data.Encode()))
|
||||
req, err := http.NewRequestWithContext(context.Background(), "POST", apiURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package webhook_notifier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -146,7 +147,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s
|
||||
url.QueryEscape(message),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create GET request: %w", err)
|
||||
}
|
||||
@@ -180,7 +181,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s
|
||||
func (t *WebhookNotifier) sendPOST(webhookURL, heading, message string, logger *slog.Logger) error {
|
||||
body := t.buildRequestBody(heading, message)
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create POST request: %w", err)
|
||||
}
|
||||
|
||||
@@ -343,10 +343,8 @@ func (s *NotifierService) TransferNotifierToWorkspace(
|
||||
return ErrNotifierHasOtherAttachedDatabasesCannotTransfer
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrNotifierHasAttachedDatabasesCannotTransfer
|
||||
}
|
||||
} else if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrNotifierHasAttachedDatabasesCannotTransfer
|
||||
}
|
||||
|
||||
sourceWorkspaceID := existingNotifier.WorkspaceID
|
||||
|
||||
@@ -162,3 +162,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
|
||||
cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
|
||||
assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
@@ -101,27 +101,6 @@ func (s *RestoresScheduler) IsSchedulerRunning() bool {
|
||||
return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(
|
||||
restores_core.RestoreStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
|
||||
// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
|
||||
if dbCache == nil {
|
||||
@@ -326,6 +305,27 @@ func (s *RestoresScheduler) onRestoreCompleted(nodeID, restoreID uuid.UUID) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) failRestoresInProgress() error {
|
||||
restoresInProgress, err := s.restoreRepository.FindByStatus(
|
||||
restores_core.RestoreStatusInProgress,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, restore := range restoresInProgress {
|
||||
failMessage := "Restore failed due to application restart"
|
||||
restore.FailMessage = &failMessage
|
||||
restore.Status = restores_core.RestoreStatusFailed
|
||||
|
||||
if err := s.restoreRepository.Save(restore); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
|
||||
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
|
||||
if err != nil {
|
||||
|
||||
@@ -324,7 +324,7 @@ func CreateTestRestore(
|
||||
Port: 5432,
|
||||
Username: "test",
|
||||
Password: "test",
|
||||
Database: stringPtr("testdb"),
|
||||
Database: func() *string { s := "testdb"; return &s }(),
|
||||
Version: "16",
|
||||
},
|
||||
}
|
||||
@@ -336,7 +336,3 @@ func CreateTestRestore(
|
||||
|
||||
return restore
|
||||
}
|
||||
|
||||
func stringPtr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
|
||||
@@ -199,6 +199,54 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) CancelRestore(
|
||||
user *users_models.User,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel restore for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel restore for this database")
|
||||
}
|
||||
|
||||
if restore.Status != restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is not in progress")
|
||||
}
|
||||
|
||||
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateVersionCompatibility(
|
||||
backupDatabase *databases.Database,
|
||||
requestDTO restores_core.RestoreBackupRequest,
|
||||
@@ -295,6 +343,7 @@ func (s *RestoreService) validateVersionCompatibility(
|
||||
`For example, you can restore MongoDB 6.0 backup to MongoDB 6.0, 7.0 or higher. But cannot restore to 5.0`)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -368,51 +417,3 @@ func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) CancelRestore(
|
||||
user *users_models.User,
|
||||
restoreID uuid.UUID,
|
||||
) error {
|
||||
restore, err := s.restoreRepository.FindByID(restoreID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backup, err := s.backupService.GetBackup(restore.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel restore for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel restore for this database")
|
||||
}
|
||||
|
||||
if restore.Status != restores_core.RestoreStatusInProgress {
|
||||
return errors.New("restore is not in progress")
|
||||
}
|
||||
|
||||
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -304,11 +304,12 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
|
||||
_, copyErr := io.Copy(stdinPipe, backupReader)
|
||||
// Close stdin pipe to signal EOF to pg_restore - critical for proper termination
|
||||
closeErr := stdinPipe.Close()
|
||||
if copyErr != nil {
|
||||
switch {
|
||||
case copyErr != nil:
|
||||
copyErrCh <- fmt.Errorf("copy to stdin: %w", copyErr)
|
||||
} else if closeErr != nil {
|
||||
case closeErr != nil:
|
||||
copyErrCh <- fmt.Errorf("close stdin: %w", closeErr)
|
||||
} else {
|
||||
default:
|
||||
copyErrCh <- nil
|
||||
}
|
||||
}()
|
||||
@@ -764,10 +765,12 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
|
||||
)
|
||||
|
||||
// Check for specific PostgreSQL error patterns
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
var exitErr *exec.ExitError
|
||||
if errors.As(waitErr, &exitErr) {
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
|
||||
switch {
|
||||
case exitCode == 1 && strings.TrimSpace(stderrStr) == "":
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s failed with exit status 1 but provided no error details. "+
|
||||
"This often indicates: "+
|
||||
@@ -782,45 +785,46 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
|
||||
pgBin,
|
||||
strings.Join(args, " "),
|
||||
)
|
||||
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
|
||||
case exitCode == -1073741819: // 0xC0000005 in decimal
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
stderrStr,
|
||||
)
|
||||
} else if exitCode == 1 || exitCode == 2 {
|
||||
case exitCode == 1 || exitCode == 2:
|
||||
// Check for common connection and authentication issues
|
||||
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
|
||||
switch {
|
||||
case containsIgnoreCase(stderrStr, "pg_hba.conf"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection rejected by server configuration (pg_hba.conf). stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
|
||||
case containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed - no password supplied. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
|
||||
case containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL SSL connection failed. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
|
||||
case containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection refused. Check if the server is running and accessible. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
|
||||
case containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "timeout") {
|
||||
case containsIgnoreCase(stderrStr, "timeout"):
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection timeout. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
|
||||
case containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist"):
|
||||
backupDbName := "unknown"
|
||||
if database.Postgresql != nil && database.Postgresql.Database != nil {
|
||||
backupDbName = *database.Postgresql.Database
|
||||
|
||||
@@ -109,7 +109,7 @@ func (s *AzureBlobStorage) SaveFile(
|
||||
return fmt.Errorf("read error: %w", readErr)
|
||||
}
|
||||
|
||||
blockID := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%06d", blockNumber)))
|
||||
blockID := base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "%06d", blockNumber))
|
||||
|
||||
_, err := blockBlobClient.StageBlock(
|
||||
ctx,
|
||||
@@ -337,7 +337,7 @@ func (s *AzureBlobStorage) buildBlobName(fileName string) string {
|
||||
prefix = strings.TrimPrefix(prefix, "/")
|
||||
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix = prefix + "/"
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
return prefix + fileName
|
||||
|
||||
@@ -167,7 +167,7 @@ func (s *GoogleDriveStorage) GetFile(
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := driveService.Files.Get(fileIDGoogle).Download()
|
||||
resp, err := driveService.Files.Get(fileIDGoogle).Download() //nolint:bodyclose
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file from Google Drive: %w", err)
|
||||
}
|
||||
@@ -358,7 +358,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(
|
||||
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
|
||||
strings.Contains(refreshErr.Error(), "refresh token") {
|
||||
return fmt.Errorf(
|
||||
"google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %v",
|
||||
"google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %w",
|
||||
err,
|
||||
refreshErr,
|
||||
)
|
||||
@@ -488,7 +488,7 @@ func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) e
|
||||
// maskSensitiveData masks sensitive information in token JSON for logging
|
||||
func maskSensitiveData(tokenJSON string) string {
|
||||
// Replace sensitive values with masked versions
|
||||
var data map[string]interface{}
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &data); err != nil {
|
||||
return "invalid JSON"
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ func (n *NASStorage) createConnectionWithContext(ctx context.Context) (net.Conn,
|
||||
ServerName: n.Host,
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
|
||||
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSL connection to %s: %w", address, err)
|
||||
}
|
||||
|
||||
@@ -207,7 +207,7 @@ func (s *S3Storage) GetFile(
|
||||
// Check if the file actually exists by reading the first byte
|
||||
buf := make([]byte, 1)
|
||||
_, readErr := object.Read(buf)
|
||||
if readErr != nil && readErr != io.EOF {
|
||||
if readErr != nil && !errors.Is(readErr, io.EOF) {
|
||||
_ = object.Close()
|
||||
return nil, fmt.Errorf("file does not exist in S3: %w", readErr)
|
||||
}
|
||||
@@ -372,7 +372,7 @@ func (s *S3Storage) buildObjectKey(fileName string) string {
|
||||
prefix = strings.TrimPrefix(prefix, "/")
|
||||
|
||||
if !strings.HasSuffix(prefix, "/") {
|
||||
prefix = prefix + "/"
|
||||
prefix += "/"
|
||||
}
|
||||
|
||||
return prefix + fileName
|
||||
@@ -428,11 +428,11 @@ func (s *S3Storage) getClientParams(
|
||||
endpoint = s.S3Endpoint
|
||||
useSSL = true
|
||||
|
||||
if strings.HasPrefix(endpoint, "http://") {
|
||||
if after, ok := strings.CutPrefix(endpoint, "http://"); ok {
|
||||
useSSL = false
|
||||
endpoint = strings.TrimPrefix(endpoint, "http://")
|
||||
} else if strings.HasPrefix(endpoint, "https://") {
|
||||
endpoint = strings.TrimPrefix(endpoint, "https://")
|
||||
endpoint = after
|
||||
} else if after, ok := strings.CutPrefix(endpoint, "https://"); ok {
|
||||
endpoint = after
|
||||
}
|
||||
|
||||
if endpoint == "" {
|
||||
|
||||
@@ -298,12 +298,7 @@ func (s *SFTPStorage) connectWithContext(
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
}
|
||||
|
||||
var hostKeyCallback ssh.HostKeyCallback
|
||||
if s.SkipHostKeyVerify {
|
||||
hostKeyCallback = ssh.InsecureIgnoreHostKey()
|
||||
} else {
|
||||
hostKeyCallback = ssh.InsecureIgnoreHostKey()
|
||||
}
|
||||
hostKeyCallback := ssh.InsecureIgnoreHostKey()
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: s.Username,
|
||||
|
||||
@@ -364,10 +364,8 @@ func (s *StorageService) TransferStorageToWorkspace(
|
||||
return ErrStorageHasOtherAttachedDatabasesCannotTransfer
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrStorageHasAttachedDatabasesCannotTransfer
|
||||
}
|
||||
} else if len(attachedDatabasesIDs) > 0 {
|
||||
return ErrStorageHasAttachedDatabasesCannotTransfer
|
||||
}
|
||||
|
||||
sourceWorkspaceID := existingStorage.WorkspaceID
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package features
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
@@ -60,11 +59,9 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
|
||||
|
||||
// Call SetupDependencies concurrently from 10 goroutines
|
||||
for range 10 {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Go(func() {
|
||||
audit_logs.SetupDependencies()
|
||||
}()
|
||||
})
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
@@ -73,8 +70,7 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
|
||||
|
||||
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
|
||||
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
// Create a test background service
|
||||
backgroundService := audit_logs.GetAuditLogBackgroundService()
|
||||
@@ -107,8 +103,7 @@ func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
|
||||
|
||||
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
|
||||
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx := t.Context()
|
||||
|
||||
scheduler := backuping.GetBackupsScheduler()
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package users_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -31,7 +32,7 @@ func (r *UserRepository) GetUserByEmail(email string) (*users_models.User, error
|
||||
var user users_models.User
|
||||
|
||||
if err := storage.GetDb().Where("email = ?", email).First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -179,7 +180,7 @@ func (r *UserRepository) UpdateUserInfo(userID uuid.UUID, name, email *string) e
|
||||
func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models.User, error) {
|
||||
var user users_models.User
|
||||
err := storage.GetDb().Where("github_oauth_id = ?", githubID).First(&user).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
@@ -191,7 +192,7 @@ func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models.
|
||||
func (r *UserRepository) GetUserByGoogleOAuthID(googleID string) (*users_models.User, error) {
|
||||
var user users_models.User
|
||||
err := storage.GetDb().Where("google_oauth_id = ?", googleID).First(&user).Error
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package users_repositories
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -14,7 +16,7 @@ func (r *UsersSettingsRepository) GetSettings() (*user_models.UsersSettings, err
|
||||
var settings user_models.UsersSettings
|
||||
|
||||
if err := storage.GetDb().First(&settings).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// Create default settings if none exist
|
||||
defaultSettings := &user_models.UsersSettings{
|
||||
ID: uuid.New(),
|
||||
|
||||
@@ -685,7 +685,11 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
|
||||
}
|
||||
|
||||
client := oauthConfig.Client(context.Background(), token)
|
||||
resp, err := client.Get(userAPIURL)
|
||||
githubReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user info request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(githubReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
@@ -754,7 +758,11 @@ func (s *UserService) handleGoogleOAuthWithEndpoint(
|
||||
}
|
||||
|
||||
client := oauthConfig.Client(context.Background(), token)
|
||||
resp, err := client.Get(userAPIURL)
|
||||
googleReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create user info request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(googleReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user info: %w", err)
|
||||
}
|
||||
@@ -950,7 +958,11 @@ func (s *UserService) fetchGitHubPrimaryEmail(
|
||||
emailsURL = baseURL + "/user/emails"
|
||||
}
|
||||
|
||||
resp, err := client.Get(emailsURL)
|
||||
emailsReq, err := http.NewRequestWithContext(context.Background(), "GET", emailsURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create user emails request: %w", err)
|
||||
}
|
||||
resp, err := client.Do(emailsReq)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get user emails: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,13 @@ type EmailCall struct {
|
||||
Body string
|
||||
}
|
||||
|
||||
func NewMockEmailSender() *MockEmailSender {
|
||||
return &MockEmailSender{
|
||||
SentEmails: []EmailCall{},
|
||||
ShouldFail: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) SendEmail(to, subject, body string) error {
|
||||
m.SentEmails = append(m.SentEmails, EmailCall{
|
||||
To: to,
|
||||
@@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMockEmailSender() *MockEmailSender {
|
||||
return &MockEmailSender{
|
||||
SentEmails: []EmailCall{},
|
||||
ShouldFail: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,13 @@ type EmailCall struct {
|
||||
Body string
|
||||
}
|
||||
|
||||
func NewMockEmailSender() *MockEmailSender {
|
||||
return &MockEmailSender{
|
||||
SendEmailCalls: []EmailCall{},
|
||||
ShouldFail: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MockEmailSender) SendEmail(to, subject, body string) error {
|
||||
m.SendEmailCalls = append(m.SendEmailCalls, EmailCall{
|
||||
To: to,
|
||||
@@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMockEmailSender() *MockEmailSender {
|
||||
return &MockEmailSender{
|
||||
SendEmailCalls: []EmailCall{},
|
||||
ShouldFail: false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package workspaces_testing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -433,7 +434,7 @@ func MakeAPIRequest(
|
||||
requestBody = bytes.NewBuffer(nil)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, requestBody)
|
||||
req, err := http.NewRequestWithContext(context.Background(), method, url, requestBody)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package cloudflare_turnstile
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -42,7 +44,15 @@ func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool,
|
||||
formData.Set("response", token)
|
||||
formData.Set("remoteip", remoteIP)
|
||||
|
||||
resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData)
|
||||
req, err := http.NewRequestWithContext(
|
||||
context.Background(), "POST", cloudflareTurnstileVerifyURL, strings.NewReader(formData.Encode()),
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to create Cloudflare Turnstile request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
|
||||
}
|
||||
|
||||
@@ -32,7 +32,7 @@ func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
|
||||
|
||||
// Send to VictoriaLogs if configured
|
||||
if h.victoriaLogsWriter != nil {
|
||||
attrs := make(map[string]interface{})
|
||||
attrs := make(map[string]any)
|
||||
record.Attrs(func(a slog.Attr) bool {
|
||||
attrs[a.Key] = a.Value.Any()
|
||||
return true
|
||||
|
||||
@@ -59,7 +59,7 @@ func NewVictoriaLogsWriter(url, username, password string) *VictoriaLogsWriter {
|
||||
return writer
|
||||
}
|
||||
|
||||
func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]interface{}) {
|
||||
func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]any) {
|
||||
entry := logEntry{
|
||||
Time: time.Now().UTC().Format(time.RFC3339Nano),
|
||||
Message: message,
|
||||
@@ -76,6 +76,27 @@ func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]inter
|
||||
}
|
||||
}
|
||||
|
||||
func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) {
|
||||
w.once.Do(func() {
|
||||
// Stop accepting new logs
|
||||
w.cancel()
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
w.logger.Info("VictoriaLogs writer shutdown gracefully")
|
||||
case <-time.After(timeout):
|
||||
w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (w *VictoriaLogsWriter) worker() {
|
||||
defer w.wg.Done()
|
||||
|
||||
@@ -180,24 +201,3 @@ func (w *VictoriaLogsWriter) flushBatch(batch []logEntry) {
|
||||
w.sendBatch(batch)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) {
|
||||
w.once.Do(func() {
|
||||
// Stop accepting new logs
|
||||
w.cancel()
|
||||
|
||||
// Wait for workers to finish with timeout
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
w.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
w.logger.Info("VictoriaLogs writer shutdown gracefully")
|
||||
case <-time.After(timeout):
|
||||
w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package testing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
type RequestOptions struct {
|
||||
Method string
|
||||
URL string
|
||||
Body interface{}
|
||||
Body any
|
||||
Headers map[string]string
|
||||
AuthToken string
|
||||
ExpectedStatus int
|
||||
@@ -40,7 +41,7 @@ func MakeGetRequestAndUnmarshal(
|
||||
router *gin.Engine,
|
||||
url, authToken string,
|
||||
expectedStatus int,
|
||||
responseStruct interface{},
|
||||
responseStruct any,
|
||||
) *TestResponse {
|
||||
return makeAuthenticatedRequestAndUnmarshal(
|
||||
t,
|
||||
@@ -58,7 +59,7 @@ func MakePostRequest(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
) *TestResponse {
|
||||
return makeAuthenticatedRequest(t, router, "POST", url, authToken, body, expectedStatus)
|
||||
@@ -68,9 +69,9 @@ func MakePostRequestAndUnmarshal(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
responseStruct interface{},
|
||||
responseStruct any,
|
||||
) *TestResponse {
|
||||
return makeAuthenticatedRequestAndUnmarshal(
|
||||
t,
|
||||
@@ -88,7 +89,7 @@ func MakePutRequest(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
) *TestResponse {
|
||||
return makeAuthenticatedRequest(t, router, "PUT", url, authToken, body, expectedStatus)
|
||||
@@ -98,9 +99,9 @@ func MakePutRequestAndUnmarshal(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
responseStruct interface{},
|
||||
responseStruct any,
|
||||
) *TestResponse {
|
||||
return makeAuthenticatedRequestAndUnmarshal(
|
||||
t,
|
||||
@@ -134,7 +135,7 @@ func MakeRequest(t *testing.T, router *gin.Engine, options RequestOptions) *Test
|
||||
requestBody = bytes.NewBuffer(nil)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(options.Method, options.URL, requestBody)
|
||||
req, err := http.NewRequestWithContext(context.Background(), options.Method, options.URL, requestBody)
|
||||
assert.NoError(t, err, "Failed to create HTTP request")
|
||||
|
||||
if options.Body != nil {
|
||||
@@ -167,7 +168,7 @@ func makeRequestAndUnmarshal(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
options RequestOptions,
|
||||
responseStruct interface{},
|
||||
responseStruct any,
|
||||
) *TestResponse {
|
||||
response := MakeRequest(t, router, options)
|
||||
|
||||
@@ -183,7 +184,7 @@ func makeAuthenticatedRequest(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
method, url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
) *TestResponse {
|
||||
return MakeRequest(t, router, RequestOptions{
|
||||
@@ -199,9 +200,9 @@ func makeAuthenticatedRequestAndUnmarshal(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
method, url, authToken string,
|
||||
body interface{},
|
||||
body any,
|
||||
expectedStatus int,
|
||||
responseStruct interface{},
|
||||
responseStruct any,
|
||||
) *TestResponse {
|
||||
return makeRequestAndUnmarshal(t, router, RequestOptions{
|
||||
Method: method,
|
||||
|
||||
Reference in New Issue
Block a user