REFACTOR (linters): Apply linters fixes

This commit is contained in:
Rostislav Dugin
2026-03-13 18:50:57 +03:00
parent f712e3a437
commit 7e209ff537
50 changed files with 429 additions and 384 deletions

View File

@@ -9,13 +9,11 @@ linters:
default: standard
enable:
- funcorder
- gosec
- bodyclose
- errorlint
- gocritic
- unconvert
- misspell
- dupword
- errname
- noctx
- modernize

View File

@@ -1,6 +1,7 @@
package upgrade
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -74,9 +75,17 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log Logger
}
func fetchServerVersion(host string, log Logger) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Get(host + "/api/v1/system/version")
req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil)
if err != nil {
return "", err
}
resp, err := client.Do(req)
if err != nil {
log.Warn("Could not reach server for update check, continuing", "error", err)
return "", err
@@ -104,7 +113,15 @@ func fetchServerVersion(host string, log Logger) (string, error) {
func downloadBinary(host, destPath string) error {
url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH)
resp, err := http.Get(url)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return err
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
@@ -126,7 +143,7 @@ func downloadBinary(host, destPath string) error {
}
func verifyBinary(binaryPath, expectedVersion string) error {
cmd := exec.Command(binaryPath, "version")
cmd := exec.CommandContext(context.Background(), binaryPath, "version")
output, err := cmd.Output()
if err != nil {

View File

@@ -9,13 +9,11 @@ linters:
default: standard
enable:
- funcorder
- gosec
- bodyclose
- errorlint
- gocritic
- unconvert
- misspell
- dupword
- errname
- noctx
- modernize

View File

@@ -353,7 +353,9 @@ func generateSwaggerDocs(log *slog.Logger) {
return
}
cmd := exec.Command("swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger")
cmd := exec.CommandContext(
context.Background(), "swag", "init", "-d", currentDir, "-g", "cmd/main.go", "-o", "swagger",
)
output, err := cmd.CombinedOutput()
if err != nil {
@@ -367,7 +369,7 @@ func generateSwaggerDocs(log *slog.Logger) {
func runMigrations(log *slog.Logger) {
log.Info("Running database migrations...")
cmd := exec.Command("goose", "-dir", "./migrations", "up")
cmd := exec.CommandContext(context.Background(), "goose", "-dir", "./migrations", "up")
cmd.Env = append(
os.Environ(),
"GOOSE_DRIVER=postgres",

View File

@@ -38,7 +38,7 @@ func (r *AuditLogRepository) GetGlobal(
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id`
args := []interface{}{}
args := []any{}
if beforeDate != nil {
sql += " WHERE al.created_at < ?"
@@ -75,7 +75,7 @@ func (r *AuditLogRepository) GetByUser(
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.user_id = ?`
args := []interface{}{userID}
args := []any{userID}
if beforeDate != nil {
sql += " AND al.created_at < ?"
@@ -112,7 +112,7 @@ func (r *AuditLogRepository) GetByWorkspace(
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.workspace_id = ?`
args := []interface{}{workspaceID}
args := []any{workspaceID}
if beforeDate != nil {
sql += " AND al.created_at < ?"

View File

@@ -446,22 +446,24 @@ func buildGFSKeepSet(
}
dailyCutoff := rawDailyCutoff
if weeks > 0 {
switch {
case weeks > 0:
dailyCutoff = laterOf(dailyCutoff, weeklyCutoff)
} else if months > 0 {
case months > 0:
dailyCutoff = laterOf(dailyCutoff, monthlyCutoff)
} else if years > 0 {
case years > 0:
dailyCutoff = laterOf(dailyCutoff, yearlyCutoff)
}
hourlyCutoff := rawHourlyCutoff
if days > 0 {
switch {
case days > 0:
hourlyCutoff = laterOf(hourlyCutoff, dailyCutoff)
} else if weeks > 0 {
case weeks > 0:
hourlyCutoff = laterOf(hourlyCutoff, weeklyCutoff)
} else if months > 0 {
case months > 0:
hourlyCutoff = laterOf(hourlyCutoff, monthlyCutoff)
} else if years > 0 {
case years > 0:
hourlyCutoff = laterOf(hourlyCutoff, yearlyCutoff)
}

View File

@@ -7,6 +7,10 @@ type CountingWriter struct {
BytesWritten int64
}
func NewCountingWriter(writer io.Writer) *CountingWriter {
return &CountingWriter{Writer: writer}
}
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
n, err = cw.Writer.Write(p)
cw.BytesWritten += int64(n)
@@ -16,7 +20,3 @@ func (cw *CountingWriter) Write(p []byte) (n int, err error) {
func (cw *CountingWriter) GetBytesWritten() int64 {
return cw.BytesWritten
}
func NewCountingWriter(writer io.Writer) *CountingWriter {
return &CountingWriter{Writer: writer}
}

View File

@@ -2,6 +2,7 @@ package backups_controllers
import (
"context"
"errors"
"fmt"
"io"
"net/http"
@@ -198,7 +199,7 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
ctx.JSON(
http.StatusConflict,
gin.H{
@@ -249,7 +250,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
if errors.Is(err, backups_download.ErrDownloadAlreadyInProgress) {
ctx.JSON(
http.StatusConflict,
gin.H{

View File

@@ -88,7 +88,7 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
Where("database_id = ?", databaseID).
Order("created_at DESC").
First(&backup).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}

View File

@@ -66,9 +66,7 @@ func (rl *RateLimiter) Wait(bytes int64) {
tokensNeeded := float64(bytes) - rl.availableTokens
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
if waitTime < time.Millisecond {
waitTime = time.Millisecond
}
waitTime = max(waitTime, time.Millisecond)
rl.mu.Unlock()
time.Sleep(waitTime)

View File

@@ -3,6 +3,7 @@ package backups_download
import (
"crypto/rand"
"encoding/base64"
"errors"
"time"
"github.com/google/uuid"
@@ -30,7 +31,7 @@ func (r *DownloadTokenRepository) FindByToken(token string) (*DownloadToken, err
Where("token = ?", token).
First(&downloadToken).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err

View File

@@ -4,6 +4,7 @@ import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -69,7 +70,7 @@ func NewDecryptionReader(
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
for len(r.buffer) < len(p) && !r.eof {
if err := r.readAndDecryptChunk(); err != nil {
if err == io.EOF {
if errors.Is(err, io.EOF) {
r.eof = true
break
}

View File

@@ -225,6 +225,80 @@ func (s *PostgreWalBackupService) DownloadBackupFile(
return s.backupService.GetBackupReader(backupID)
}
func (s *PostgreWalBackupService) GetNextFullBackupTime(
database *databases.Database,
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.BackupInterval == nil {
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
}
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
database.ID,
)
if err != nil {
return nil, fmt.Errorf("failed to query last full backup: %w", err)
}
var lastBackupTime *time.Time
if lastFullBackup != nil {
lastBackupTime = &lastFullBackup.CreatedAt
}
now := time.Now().UTC()
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
return &backups_dto.GetNextFullBackupTimeResponse{
NextFullBackupTime: nextTime,
}, nil
}
// ReportError creates a FAILED backup record with the agent's error message.
func (s *PostgreWalBackupService) ReportError(
database *databases.Database,
errorMsg string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return fmt.Errorf("no storage configured for database %s", database.ID)
}
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: backupConfig.Storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &errorMsg,
Encryption: backupConfig.Encryption,
CreatedAt: now,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to save error backup record: %w", err)
}
return nil
}
func (s *PostgreWalBackupService) validateWalChain(
databaseID uuid.UUID,
incomingSegment string,
@@ -432,80 +506,6 @@ func (s *PostgreWalBackupService) markFailed(backup *backups_core.Backup, errMsg
}
}
func (s *PostgreWalBackupService) GetNextFullBackupTime(
database *databases.Database,
) (*backups_dto.GetNextFullBackupTimeResponse, error) {
if err := s.validateWalBackupType(database); err != nil {
return nil, err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.BackupInterval == nil {
return nil, fmt.Errorf("no backup interval configured for database %s", database.ID)
}
lastFullBackup, err := s.backupRepository.FindLastCompletedFullWalBackupByDatabaseID(
database.ID,
)
if err != nil {
return nil, fmt.Errorf("failed to query last full backup: %w", err)
}
var lastBackupTime *time.Time
if lastFullBackup != nil {
lastBackupTime = &lastFullBackup.CreatedAt
}
now := time.Now().UTC()
nextTime := backupConfig.BackupInterval.NextTriggerTime(now, lastBackupTime)
return &backups_dto.GetNextFullBackupTimeResponse{
NextFullBackupTime: nextTime,
}, nil
}
// ReportError creates a FAILED backup record with the agent's error message.
func (s *PostgreWalBackupService) ReportError(
database *databases.Database,
errorMsg string,
) error {
if err := s.validateWalBackupType(database); err != nil {
return err
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
return fmt.Errorf("failed to get backup config: %w", err)
}
if backupConfig.Storage == nil {
return fmt.Errorf("no storage configured for database %s", database.ID)
}
now := time.Now().UTC()
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: backupConfig.Storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &errorMsg,
Encryption: backupConfig.Encryption,
CreatedAt: now,
}
backup.GenerateFilename(database.Name)
if err := s.backupRepository.Save(backup); err != nil {
return fmt.Errorf("failed to save error backup record: %w", err)
}
return nil
}
func (s *PostgreWalBackupService) resolveFullBackup(
databaseID uuid.UUID,
backupID *uuid.UUID,

View File

@@ -548,8 +548,8 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpErrorMessage(
stderrStr,
)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
var exitErr *exec.ExitError
if !errors.As(waitErr, &exitErr) {
return errors.New(errorMsg)
}

View File

@@ -565,8 +565,8 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpErrorMessage(
stderrStr,
)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
var exitErr *exec.ExitError
if !errors.As(waitErr, &exitErr) {
return errors.New(errorMsg)
}

View File

@@ -595,8 +595,8 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf("%s failed: %v stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
var exitErr *exec.ExitError
if !errors.As(waitErr, &exitErr) {
return errors.New(errorMsg)
}

View File

@@ -214,39 +214,6 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
return s.initializeDefaultConfig(databaseID)
}
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
NotificationBackupSuccess,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
})
return err
}
func (s *BackupConfigService) TransferDatabaseToWorkspace(
user *users_models.User,
databaseID uuid.UUID,
@@ -290,7 +257,8 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
s.transferNotifiers(user, database, request.TargetWorkspaceID)
}
if request.IsTransferWithStorage {
switch {
case request.IsTransferWithStorage:
if backupConfig.StorageID == nil {
return ErrDatabaseHasNoStorage
}
@@ -315,7 +283,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
if err != nil {
return err
}
} else if request.TargetStorageID != nil {
case request.TargetStorageID != nil:
targetStorage, err := s.storageService.GetStorageByID(*request.TargetStorageID)
if err != nil {
return err
@@ -332,7 +300,7 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
if err != nil {
return err
}
} else {
default:
return ErrTargetStorageNotSpecified
}
@@ -351,6 +319,39 @@ func (s *BackupConfigService) TransferDatabaseToWorkspace(
return nil
}
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
if err != nil {
return err
}
timeOfDay := "04:00"
_, err = s.backupConfigRepository.Save(&BackupConfig{
DatabaseID: databaseID,
IsBackupsEnabled: false,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: plan.MaxStoragePeriod,
MaxBackupSizeMB: plan.MaxBackupSizeMB,
MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
NotificationBackupSuccess,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
})
return err
}
func (s *BackupConfigService) transferNotifiers(
user *users_models.User,
database *databases.Database,

View File

@@ -391,7 +391,7 @@ func (m *MariadbDatabase) HasPrivilege(priv string) bool {
}
func HasPrivilege(privileges, priv string) bool {
for _, p := range strings.Split(privileges, ",") {
for p := range strings.SplitSeq(privileges, ",") {
if strings.TrimSpace(p) == priv {
return true
}

View File

@@ -451,6 +451,48 @@ func (m *MongodbDatabase) CreateReadOnlyUser(
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
// BuildMongodumpURI builds a URI suitable for mongodump (without database in path)
func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB := m.AuthDatabase
if authDB == "" {
authDB = "admin"
}
extraParams := ""
if m.IsHttps {
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
authDB,
extraParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
port,
authDB,
extraParams,
)
}
// buildConnectionURI builds a MongoDB connection URI
func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB := m.AuthDatabase
@@ -495,48 +537,6 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
)
}
// BuildMongodumpURI builds a URI suitable for mongodump (without database in path)
func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB := m.AuthDatabase
if authDB == "" {
authDB = "admin"
}
extraParams := ""
if m.IsHttps {
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
authDB,
extraParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
"mongodb://%s:%s@%s:%d/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
port,
authDB,
extraParams,
)
}
// detectMongodbVersion gets MongoDB server version from buildInfo command
func detectMongodbVersion(ctx context.Context, client *mongo.Client) (tools.MongodbVersion, error) {
adminDB := client.Database("admin")

View File

@@ -1153,8 +1153,8 @@ func isSupabaseConnection(host, username string) bool {
}
func extractSupabaseProjectID(username string) string {
if idx := strings.Index(username, "."); idx != -1 {
return username[idx+1:]
if _, after, found := strings.Cut(username, "."); found {
return after
}
return ""
}

View File

@@ -1,6 +1,7 @@
package email
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
@@ -114,7 +115,7 @@ func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error
tlsConfig := &tls.Config{ServerName: s.smtpHost}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}

View File

@@ -2,6 +2,7 @@ package discord_notifier
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -47,7 +48,7 @@ func (d *DiscordNotifier) Send(
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
}
payload := map[string]interface{}{
payload := map[string]any{
"content": fullMessage,
}
@@ -56,7 +57,7 @@ func (d *DiscordNotifier) Send(
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
req, err := http.NewRequestWithContext(context.Background(), "POST", webhookURL, bytes.NewReader(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

View File

@@ -1,6 +1,7 @@
package email_notifier
import (
"context"
"crypto/tls"
"errors"
"fmt"
@@ -207,7 +208,7 @@ func (e *EmailNotifier) createImplicitTLSClient() (*smtp.Client, func(), error)
}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(context.Background(), "tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}

View File

@@ -2,6 +2,7 @@ package slack_notifier
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -86,7 +87,8 @@ func (s *SlackNotifier) Send(
for {
attempts++
req, err := http.NewRequest(
req, err := http.NewRequestWithContext(
context.Background(),
"POST",
"https://slack.com/api/chat.postMessage",
bytes.NewReader(payload),
@@ -136,7 +138,7 @@ func (s *SlackNotifier) Send(
if err := json.NewDecoder(resp.Body).Decode(&respBody); err != nil {
raw, _ := io.ReadAll(resp.Body)
return fmt.Errorf("decode response: %v raw: %s", err, raw)
return fmt.Errorf("decode response: %w raw: %s", err, raw)
}
if !respBody.OK {

View File

@@ -2,6 +2,7 @@ package teams_notifier
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -88,7 +89,7 @@ func (n *TeamsNotifier) Send(
}
body, _ := json.Marshal(p)
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return err
}

View File

@@ -1,6 +1,7 @@
package telegram_notifier
import (
"context"
"errors"
"fmt"
"io"
@@ -65,7 +66,7 @@ func (t *TelegramNotifier) Send(
data.Set("message_thread_id", strconv.FormatInt(*t.ThreadID, 10))
}
req, err := http.NewRequest("POST", apiURL, strings.NewReader(data.Encode()))
req, err := http.NewRequestWithContext(context.Background(), "POST", apiURL, strings.NewReader(data.Encode()))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

View File

@@ -2,6 +2,7 @@ package webhook_notifier
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
@@ -146,7 +147,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s
url.QueryEscape(message),
)
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
if err != nil {
return fmt.Errorf("failed to create GET request: %w", err)
}
@@ -180,7 +181,7 @@ func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *s
func (t *WebhookNotifier) sendPOST(webhookURL, heading, message string, logger *slog.Logger) error {
body := t.buildRequestBody(heading, message)
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create POST request: %w", err)
}

View File

@@ -343,10 +343,8 @@ func (s *NotifierService) TransferNotifierToWorkspace(
return ErrNotifierHasOtherAttachedDatabasesCannotTransfer
}
}
} else {
if len(attachedDatabasesIDs) > 0 {
return ErrNotifierHasAttachedDatabasesCannotTransfer
}
} else if len(attachedDatabasesIDs) > 0 {
return ErrNotifierHasAttachedDatabasesCannotTransfer
}
sourceWorkspaceID := existingNotifier.WorkspaceID

View File

@@ -162,3 +162,7 @@ func Test_MakeRestore_WhenTaskStarts_CacheDeletedImmediately(t *testing.T) {
cachedDBAfter := restoreDatabaseCache.Get(restore.ID.String())
assert.Nil(t, cachedDBAfter, "Cache should be deleted immediately when task starts")
}
func stringPtr(s string) *string {
return &s
}

View File

@@ -101,27 +101,6 @@ func (s *RestoresScheduler) IsSchedulerRunning() bool {
return s.lastCheckTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *RestoresScheduler) failRestoresInProgress() error {
restoresInProgress, err := s.restoreRepository.FindByStatus(
restores_core.RestoreStatusInProgress,
)
if err != nil {
return err
}
for _, restore := range restoresInProgress {
failMessage := "Restore failed due to application restart"
restore.FailMessage = &failMessage
restore.Status = restores_core.RestoreStatusFailed
if err := s.restoreRepository.Save(restore); err != nil {
return err
}
}
return nil
}
func (s *RestoresScheduler) StartRestore(restoreID uuid.UUID, dbCache *RestoreDatabaseCache) error {
// If dbCache not provided, try to fetch from DB (for backward compatibility/testing)
if dbCache == nil {
@@ -326,6 +305,27 @@ func (s *RestoresScheduler) onRestoreCompleted(nodeID, restoreID uuid.UUID) {
}
}
func (s *RestoresScheduler) failRestoresInProgress() error {
restoresInProgress, err := s.restoreRepository.FindByStatus(
restores_core.RestoreStatusInProgress,
)
if err != nil {
return err
}
for _, restore := range restoresInProgress {
failMessage := "Restore failed due to application restart"
restore.FailMessage = &failMessage
restore.Status = restores_core.RestoreStatusFailed
if err := s.restoreRepository.Save(restore); err != nil {
return err
}
}
return nil
}
func (s *RestoresScheduler) checkDeadNodesAndFailRestores() error {
nodes, err := s.restoreNodesRegistry.GetAvailableNodes()
if err != nil {

View File

@@ -324,7 +324,7 @@ func CreateTestRestore(
Port: 5432,
Username: "test",
Password: "test",
Database: stringPtr("testdb"),
Database: func() *string { s := "testdb"; return &s }(),
Version: "16",
},
}
@@ -336,7 +336,3 @@ func CreateTestRestore(
return restore
}
func stringPtr(s string) *string {
return &s
}

View File

@@ -199,6 +199,54 @@ func (s *RestoreService) RestoreBackupWithAuth(
return nil
}
func (s *RestoreService) CancelRestore(
user *users_models.User,
restoreID uuid.UUID,
) error {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
return err
}
backup, err := s.backupService.GetBackup(restore.BackupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel restore for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel restore for this database")
}
if restore.Status != restores_core.RestoreStatusInProgress {
return errors.New("restore is not in progress")
}
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}
func (s *RestoreService) validateVersionCompatibility(
backupDatabase *databases.Database,
requestDTO restores_core.RestoreBackupRequest,
@@ -295,6 +343,7 @@ func (s *RestoreService) validateVersionCompatibility(
`For example, you can restore MongoDB 6.0 backup to MongoDB 6.0, 7.0 or higher. But cannot restore to 5.0`)
}
}
return nil
}
@@ -368,51 +417,3 @@ func (s *RestoreService) validateNoParallelRestores(databaseID uuid.UUID) error
return nil
}
func (s *RestoreService) CancelRestore(
user *users_models.User,
restoreID uuid.UUID,
) error {
restore, err := s.restoreRepository.FindByID(restoreID)
if err != nil {
return err
}
backup, err := s.backupService.GetBackup(restore.BackupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel restore for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel restore for this database")
}
if restore.Status != restores_core.RestoreStatusInProgress {
return errors.New("restore is not in progress")
}
if err := s.taskCancelManager.CancelTask(restoreID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Restore cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}

View File

@@ -304,11 +304,12 @@ func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
_, copyErr := io.Copy(stdinPipe, backupReader)
// Close stdin pipe to signal EOF to pg_restore - critical for proper termination
closeErr := stdinPipe.Close()
if copyErr != nil {
switch {
case copyErr != nil:
copyErrCh <- fmt.Errorf("copy to stdin: %w", copyErr)
} else if closeErr != nil {
case closeErr != nil:
copyErrCh <- fmt.Errorf("close stdin: %w", closeErr)
} else {
default:
copyErrCh <- nil
}
}()
@@ -764,10 +765,12 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
)
// Check for specific PostgreSQL error patterns
if exitErr, ok := waitErr.(*exec.ExitError); ok {
var exitErr *exec.ExitError
if errors.As(waitErr, &exitErr) {
exitCode := exitErr.ExitCode()
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
switch {
case exitCode == 1 && strings.TrimSpace(stderrStr) == "":
errorMsg = fmt.Sprintf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
@@ -782,45 +785,46 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
pgBin,
strings.Join(args, " "),
)
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
case exitCode == -1073741819: // 0xC0000005 in decimal
errorMsg = fmt.Sprintf(
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
} else if exitCode == 1 || exitCode == 2 {
case exitCode == 1 || exitCode == 2:
// Check for common connection and authentication issues
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
switch {
case containsIgnoreCase(stderrStr, "pg_hba.conf"):
errorMsg = fmt.Sprintf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
case containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth"):
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed - no password supplied. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
case containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection"):
errorMsg = fmt.Sprintf(
"PostgreSQL SSL connection failed. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
case containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused"):
errorMsg = fmt.Sprintf(
"PostgreSQL connection refused. Check if the server is running and accessible. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
case containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password"):
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "timeout") {
case containsIgnoreCase(stderrStr, "timeout"):
errorMsg = fmt.Sprintf(
"PostgreSQL connection timeout. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
case containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist"):
backupDbName := "unknown"
if database.Postgresql != nil && database.Postgresql.Database != nil {
backupDbName = *database.Postgresql.Database

View File

@@ -109,7 +109,7 @@ func (s *AzureBlobStorage) SaveFile(
return fmt.Errorf("read error: %w", readErr)
}
blockID := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%06d", blockNumber)))
blockID := base64.StdEncoding.EncodeToString(fmt.Appendf(nil, "%06d", blockNumber))
_, err := blockBlobClient.StageBlock(
ctx,
@@ -337,7 +337,7 @@ func (s *AzureBlobStorage) buildBlobName(fileName string) string {
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
prefix += "/"
}
return prefix + fileName

View File

@@ -167,7 +167,7 @@ func (s *GoogleDriveStorage) GetFile(
return err
}
resp, err := driveService.Files.Get(fileIDGoogle).Download()
resp, err := driveService.Files.Get(fileIDGoogle).Download() //nolint:bodyclose
if err != nil {
return fmt.Errorf("failed to download file from Google Drive: %w", err)
}
@@ -358,7 +358,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
strings.Contains(refreshErr.Error(), "refresh token") {
return fmt.Errorf(
"google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %v",
"google drive refresh token has expired. Please re-authenticate and update your token configuration. Original error: %w. Refresh error: %w",
err,
refreshErr,
)
@@ -488,7 +488,7 @@ func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) e
// maskSensitiveData masks sensitive information in token JSON for logging
func maskSensitiveData(tokenJSON string) string {
// Replace sensitive values with masked versions
var data map[string]interface{}
var data map[string]any
if err := json.Unmarshal([]byte(tokenJSON), &data); err != nil {
return "invalid JSON"
}

View File

@@ -356,7 +356,7 @@ func (n *NASStorage) createConnectionWithContext(ctx context.Context) (net.Conn,
ServerName: n.Host,
InsecureSkipVerify: false,
}
conn, err := tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
conn, err := (&tls.Dialer{NetDialer: dialer, Config: tlsConfig}).DialContext(ctx, "tcp", address)
if err != nil {
return nil, fmt.Errorf("failed to create SSL connection to %s: %w", address, err)
}

View File

@@ -207,7 +207,7 @@ func (s *S3Storage) GetFile(
// Check if the file actually exists by reading the first byte
buf := make([]byte, 1)
_, readErr := object.Read(buf)
if readErr != nil && readErr != io.EOF {
if readErr != nil && !errors.Is(readErr, io.EOF) {
_ = object.Close()
return nil, fmt.Errorf("file does not exist in S3: %w", readErr)
}
@@ -372,7 +372,7 @@ func (s *S3Storage) buildObjectKey(fileName string) string {
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
prefix += "/"
}
return prefix + fileName
@@ -428,11 +428,11 @@ func (s *S3Storage) getClientParams(
endpoint = s.S3Endpoint
useSSL = true
if strings.HasPrefix(endpoint, "http://") {
if after, ok := strings.CutPrefix(endpoint, "http://"); ok {
useSSL = false
endpoint = strings.TrimPrefix(endpoint, "http://")
} else if strings.HasPrefix(endpoint, "https://") {
endpoint = strings.TrimPrefix(endpoint, "https://")
endpoint = after
} else if after, ok := strings.CutPrefix(endpoint, "https://"); ok {
endpoint = after
}
if endpoint == "" {

View File

@@ -298,12 +298,7 @@ func (s *SFTPStorage) connectWithContext(
authMethods = append(authMethods, ssh.PublicKeys(signer))
}
var hostKeyCallback ssh.HostKeyCallback
if s.SkipHostKeyVerify {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
} else {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
}
hostKeyCallback := ssh.InsecureIgnoreHostKey()
config := &ssh.ClientConfig{
User: s.Username,

View File

@@ -364,10 +364,8 @@ func (s *StorageService) TransferStorageToWorkspace(
return ErrStorageHasOtherAttachedDatabasesCannotTransfer
}
}
} else {
if len(attachedDatabasesIDs) > 0 {
return ErrStorageHasAttachedDatabasesCannotTransfer
}
} else if len(attachedDatabasesIDs) > 0 {
return ErrStorageHasAttachedDatabasesCannotTransfer
}
sourceWorkspaceID := existingStorage.WorkspaceID

View File

@@ -1,7 +1,6 @@
package features
import (
"context"
"fmt"
"sync"
"testing"
@@ -60,11 +59,9 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
// Call SetupDependencies concurrently from 10 goroutines
for range 10 {
wg.Add(1)
go func() {
defer wg.Done()
wg.Go(func() {
audit_logs.SetupDependencies()
}()
})
}
wg.Wait()
@@ -73,8 +70,7 @@ func Test_SetupDependencies_ConcurrentCalls_Safe(t *testing.T) {
// Test_BackgroundService_Run_CalledTwice_Panics verifies Run() panics on duplicate calls
func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
// Create a test background service
backgroundService := audit_logs.GetAuditLogBackgroundService()
@@ -107,8 +103,7 @@ func Test_BackgroundService_Run_CalledTwice_Panics(t *testing.T) {
// Test_BackupsScheduler_Run_CalledTwice_Panics verifies scheduler panics on duplicate calls
func Test_BackupsScheduler_Run_CalledTwice_Panics(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx := t.Context()
scheduler := backuping.GetBackupsScheduler()

View File

@@ -1,6 +1,7 @@
package users_repositories
import (
"errors"
"fmt"
"time"
@@ -31,7 +32,7 @@ func (r *UserRepository) GetUserByEmail(email string) (*users_models.User, error
var user users_models.User
if err := storage.GetDb().Where("email = ?", email).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
@@ -179,7 +180,7 @@ func (r *UserRepository) UpdateUserInfo(userID uuid.UUID, name, email *string) e
func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models.User, error) {
var user users_models.User
err := storage.GetDb().Where("github_oauth_id = ?", githubID).First(&user).Error
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {
@@ -191,7 +192,7 @@ func (r *UserRepository) GetUserByGitHubOAuthID(githubID string) (*users_models.
func (r *UserRepository) GetUserByGoogleOAuthID(googleID string) (*users_models.User, error) {
var user users_models.User
err := storage.GetDb().Where("google_oauth_id = ?", googleID).First(&user).Error
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
if err != nil {

View File

@@ -1,6 +1,8 @@
package users_repositories
import (
"errors"
"github.com/google/uuid"
"gorm.io/gorm"
@@ -14,7 +16,7 @@ func (r *UsersSettingsRepository) GetSettings() (*user_models.UsersSettings, err
var settings user_models.UsersSettings
if err := storage.GetDb().First(&settings).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
// Create default settings if none exist
defaultSettings := &user_models.UsersSettings{
ID: uuid.New(),

View File

@@ -685,7 +685,11 @@ func (s *UserService) handleGitHubOAuthWithEndpoint(
}
client := oauthConfig.Client(context.Background(), token)
resp, err := client.Get(userAPIURL)
githubReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create user info request: %w", err)
}
resp, err := client.Do(githubReq)
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
@@ -754,7 +758,11 @@ func (s *UserService) handleGoogleOAuthWithEndpoint(
}
client := oauthConfig.Client(context.Background(), token)
resp, err := client.Get(userAPIURL)
googleReq, err := http.NewRequestWithContext(context.Background(), "GET", userAPIURL, nil)
if err != nil {
return nil, fmt.Errorf("failed to create user info request: %w", err)
}
resp, err := client.Do(googleReq)
if err != nil {
return nil, fmt.Errorf("failed to get user info: %w", err)
}
@@ -950,7 +958,11 @@ func (s *UserService) fetchGitHubPrimaryEmail(
emailsURL = baseURL + "/user/emails"
}
resp, err := client.Get(emailsURL)
emailsReq, err := http.NewRequestWithContext(context.Background(), "GET", emailsURL, nil)
if err != nil {
return "", fmt.Errorf("failed to create user emails request: %w", err)
}
resp, err := client.Do(emailsReq)
if err != nil {
return "", fmt.Errorf("failed to get user emails: %w", err)
}

View File

@@ -13,6 +13,13 @@ type EmailCall struct {
Body string
}
func NewMockEmailSender() *MockEmailSender {
return &MockEmailSender{
SentEmails: []EmailCall{},
ShouldFail: false,
}
}
func (m *MockEmailSender) SendEmail(to, subject, body string) error {
m.SentEmails = append(m.SentEmails, EmailCall{
To: to,
@@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error {
}
return nil
}
func NewMockEmailSender() *MockEmailSender {
return &MockEmailSender{
SentEmails: []EmailCall{},
ShouldFail: false,
}
}

View File

@@ -13,6 +13,13 @@ type EmailCall struct {
Body string
}
func NewMockEmailSender() *MockEmailSender {
return &MockEmailSender{
SendEmailCalls: []EmailCall{},
ShouldFail: false,
}
}
func (m *MockEmailSender) SendEmail(to, subject, body string) error {
m.SendEmailCalls = append(m.SendEmailCalls, EmailCall{
To: to,
@@ -24,10 +31,3 @@ func (m *MockEmailSender) SendEmail(to, subject, body string) error {
}
return nil
}
func NewMockEmailSender() *MockEmailSender {
return &MockEmailSender{
SendEmailCalls: []EmailCall{},
ShouldFail: false,
}
}

View File

@@ -2,6 +2,7 @@ package workspaces_testing
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
@@ -433,7 +434,7 @@ func MakeAPIRequest(
requestBody = bytes.NewBuffer(nil)
}
req, err := http.NewRequest(method, url, requestBody)
req, err := http.NewRequestWithContext(context.Background(), method, url, requestBody)
if err != nil {
panic(err)
}

View File

@@ -1,12 +1,14 @@
package cloudflare_turnstile
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
@@ -42,7 +44,15 @@ func (s *CloudflareTurnstileService) VerifyToken(token, remoteIP string) (bool,
formData.Set("response", token)
formData.Set("remoteip", remoteIP)
resp, err := http.PostForm(cloudflareTurnstileVerifyURL, formData)
req, err := http.NewRequestWithContext(
context.Background(), "POST", cloudflareTurnstileVerifyURL, strings.NewReader(formData.Encode()),
)
if err != nil {
return false, fmt.Errorf("failed to create Cloudflare Turnstile request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return false, fmt.Errorf("failed to verify Cloudflare Turnstile: %w", err)
}

View File

@@ -32,7 +32,7 @@ func (h *MultiHandler) Handle(ctx context.Context, record slog.Record) error {
// Send to VictoriaLogs if configured
if h.victoriaLogsWriter != nil {
attrs := make(map[string]interface{})
attrs := make(map[string]any)
record.Attrs(func(a slog.Attr) bool {
attrs[a.Key] = a.Value.Any()
return true

View File

@@ -59,7 +59,7 @@ func NewVictoriaLogsWriter(url, username, password string) *VictoriaLogsWriter {
return writer
}
func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]interface{}) {
func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]any) {
entry := logEntry{
Time: time.Now().UTC().Format(time.RFC3339Nano),
Message: message,
@@ -76,6 +76,27 @@ func (w *VictoriaLogsWriter) Write(level, message string, attrs map[string]inter
}
}
func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) {
w.once.Do(func() {
// Stop accepting new logs
w.cancel()
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
w.wg.Wait()
close(done)
}()
select {
case <-done:
w.logger.Info("VictoriaLogs writer shutdown gracefully")
case <-time.After(timeout):
w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost")
}
})
}
func (w *VictoriaLogsWriter) worker() {
defer w.wg.Done()
@@ -180,24 +201,3 @@ func (w *VictoriaLogsWriter) flushBatch(batch []logEntry) {
w.sendBatch(batch)
}
}
func (w *VictoriaLogsWriter) Shutdown(timeout time.Duration) {
w.once.Do(func() {
// Stop accepting new logs
w.cancel()
// Wait for workers to finish with timeout
done := make(chan struct{})
go func() {
w.wg.Wait()
close(done)
}()
select {
case <-done:
w.logger.Info("VictoriaLogs writer shutdown gracefully")
case <-time.After(timeout):
w.logger.Warn("VictoriaLogs writer shutdown timeout, some logs may be lost")
}
})
}

View File

@@ -2,6 +2,7 @@ package testing
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
@@ -14,7 +15,7 @@ import (
type RequestOptions struct {
Method string
URL string
Body interface{}
Body any
Headers map[string]string
AuthToken string
ExpectedStatus int
@@ -40,7 +41,7 @@ func MakeGetRequestAndUnmarshal(
router *gin.Engine,
url, authToken string,
expectedStatus int,
responseStruct interface{},
responseStruct any,
) *TestResponse {
return makeAuthenticatedRequestAndUnmarshal(
t,
@@ -58,7 +59,7 @@ func MakePostRequest(
t *testing.T,
router *gin.Engine,
url, authToken string,
body interface{},
body any,
expectedStatus int,
) *TestResponse {
return makeAuthenticatedRequest(t, router, "POST", url, authToken, body, expectedStatus)
@@ -68,9 +69,9 @@ func MakePostRequestAndUnmarshal(
t *testing.T,
router *gin.Engine,
url, authToken string,
body interface{},
body any,
expectedStatus int,
responseStruct interface{},
responseStruct any,
) *TestResponse {
return makeAuthenticatedRequestAndUnmarshal(
t,
@@ -88,7 +89,7 @@ func MakePutRequest(
t *testing.T,
router *gin.Engine,
url, authToken string,
body interface{},
body any,
expectedStatus int,
) *TestResponse {
return makeAuthenticatedRequest(t, router, "PUT", url, authToken, body, expectedStatus)
@@ -98,9 +99,9 @@ func MakePutRequestAndUnmarshal(
t *testing.T,
router *gin.Engine,
url, authToken string,
body interface{},
body any,
expectedStatus int,
responseStruct interface{},
responseStruct any,
) *TestResponse {
return makeAuthenticatedRequestAndUnmarshal(
t,
@@ -134,7 +135,7 @@ func MakeRequest(t *testing.T, router *gin.Engine, options RequestOptions) *Test
requestBody = bytes.NewBuffer(nil)
}
req, err := http.NewRequest(options.Method, options.URL, requestBody)
req, err := http.NewRequestWithContext(context.Background(), options.Method, options.URL, requestBody)
assert.NoError(t, err, "Failed to create HTTP request")
if options.Body != nil {
@@ -167,7 +168,7 @@ func makeRequestAndUnmarshal(
t *testing.T,
router *gin.Engine,
options RequestOptions,
responseStruct interface{},
responseStruct any,
) *TestResponse {
response := MakeRequest(t, router, options)
@@ -183,7 +184,7 @@ func makeAuthenticatedRequest(
t *testing.T,
router *gin.Engine,
method, url, authToken string,
body interface{},
body any,
expectedStatus int,
) *TestResponse {
return MakeRequest(t, router, RequestOptions{
@@ -199,9 +200,9 @@ func makeAuthenticatedRequestAndUnmarshal(
t *testing.T,
router *gin.Engine,
method, url, authToken string,
body interface{},
body any,
expectedStatus int,
responseStruct interface{},
responseStruct any,
) *TestResponse {
return makeRequestAndUnmarshal(t, router, RequestOptions{
Method: method,