diff --git a/agent/.gitignore b/agent/.gitignore index 3b479da..554a966 100644 --- a/agent/.gitignore +++ b/agent/.gitignore @@ -22,4 +22,5 @@ temp/ valkey-data/ victoria-logs-data/ databasus.json -.test-tmp/ \ No newline at end of file +.test-tmp/ +databasus.log \ No newline at end of file diff --git a/agent/cmd/main.go b/agent/cmd/main.go index 38b90c5..c4b18b9 100644 --- a/agent/cmd/main.go +++ b/agent/cmd/main.go @@ -9,6 +9,7 @@ import ( "strings" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" "databasus-agent/internal/features/start" "databasus-agent/internal/features/upgrade" "databasus-agent/internal/logger" @@ -25,6 +26,8 @@ func main() { switch os.Args[1] { case "start": runStart(os.Args[2:]) + case "_run": + runDaemon(os.Args[2:]) case "stop": runStop() case "status": @@ -43,7 +46,6 @@ func main() { func runStart(args []string) { fs := flag.NewFlagSet("start", flag.ExitOnError) - isDebug := fs.Bool("debug", false, "Enable debug logging") isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check") cfg := &config.Config{} @@ -53,26 +55,51 @@ func runStart(args []string) { fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err) } - logger.Init(*isDebug) log := logger.GetLogger() isDev := checkIsDevelopment() runUpdateCheck(cfg.DatabasusHost, *isSkipUpdate, isDev, log) - if err := start.Run(cfg, log); err != nil { + if err := start.Start(cfg, log); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } } +func runDaemon(args []string) { + fs := flag.NewFlagSet("_run", flag.ExitOnError) + + if err := fs.Parse(args); err != nil { + os.Exit(1) + } + + log := logger.GetLogger() + + cfg := &config.Config{} + cfg.LoadFromJSON() + + if err := start.RunDaemon(cfg, log); err != nil { + log.Error("Agent exited with error", "error", err) + os.Exit(1) + } +} + func runStop() { - logger.Init(false) - logger.GetLogger().Info("stop: stub — not yet implemented") + log := logger.GetLogger() + + if err := start.Stop(log); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } } func runStatus() { - logger.Init(false) - logger.GetLogger().Info("status: stub — not yet implemented") + log := logger.GetLogger() + + if err := start.Status(log); err != nil { + fmt.Fprintf(os.Stderr, "Error: %v\n", err) + os.Exit(1) + } } func runRestore(args []string) { @@ -82,7 +109,6 @@ func runRestore(args []string) { backupID := fs.String("backup-id", "", "Full backup UUID (optional)") targetTime := fs.String("target-time", "", "PITR target time in RFC3339 (optional)") isYes := fs.Bool("yes", false, "Skip confirmation prompt") - isDebug := fs.Bool("debug", false, "Enable debug logging") isSkipUpdate := fs.Bool("skip-update", false, "Skip auto-update check") cfg := &config.Config{} @@ -92,7 +118,6 @@ func runRestore(args []string) { fmt.Fprintf(os.Stderr, "Failed to save config: %v\n", err) } - logger.Init(*isDebug) log := logger.GetLogger() isDev := checkIsDevelopment() @@ -126,7 +151,9 @@ func runUpdateCheck(host string, isSkipUpdate, isDev bool, log *slog.Logger) { return } - if err := upgrade.CheckAndUpdate(host, Version, isDev, log); err != nil { + apiClient := api.NewClient(host, "", log) + + if err := upgrade.CheckAndUpdate(apiClient, Version, isDev, log); err != nil { log.Error("Auto-update failed", "error", err) os.Exit(1) } diff --git a/agent/go.mod b/agent/go.mod index a4259cb..ec3248c 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -3,6 +3,7 @@ module databasus-agent go 1.26.1 require ( + github.com/go-resty/resty/v2 v2.17.2 github.com/jackc/pgx/v5 v5.8.0 github.com/klauspost/compress v1.18.4 github.com/stretchr/testify v1.11.1 @@ -15,6 +16,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect + golang.org/x/net v0.43.0 // indirect golang.org/x/text v0.29.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/agent/go.sum b/agent/go.sum index fc4a5c2..76740e0 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -2,6 +2,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-resty/resty/v2 v2.17.2 h1:FQW5oHYcIlkCNrMD2lloGScxcHJ0gkjshV3qcQAyHQk= +github.com/go-resty/resty/v2 v2.17.2/go.mod h1:kCKZ3wWmwJaNc7S29BRtUhJwy7iqmn+2mLtQrOyQlVA= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -25,10 +27,14 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/text v0.29.0 h1:1neNs90w9YzJ9BocxfsQNHKuAT4pkghyXc4nhZ6sJvk= golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/agent/internal/config/config.go b/agent/internal/config/config.go index 17709de..174d7af 100644 --- a/agent/internal/config/config.go +++ b/agent/internal/config/config.go @@ -73,6 +73,11 @@ func (c *Config) SaveToJSON() error { return os.WriteFile(configFileName, data, 0o644) } +func (c *Config) LoadFromJSON() { + c.loadFromJSON() + c.applyDefaults() +} + func (c *Config) loadFromJSON() { data, err := os.ReadFile(configFileName) if err != nil { diff --git a/agent/internal/features/api/api.go b/agent/internal/features/api/api.go new file mode 100644 index 0000000..322a619 --- /dev/null +++ b/agent/internal/features/api/api.go @@ -0,0 +1,218 @@ +package api + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "time" + + "github.com/go-resty/resty/v2" +) + +const ( + chainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup" + nextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time" + uploadPath = "/api/v1/backups/postgres/wal/upload" + reportErrorPath = "/api/v1/backups/postgres/wal/error" + versionPath = "/api/v1/system/version" + agentBinaryPath = "/api/v1/system/agent" + + apiCallTimeout = 30 * time.Second + maxRetryAttempts = 3 + retryBaseDelay = 1 * time.Second +) + +type Client struct { + json *resty.Client + stream *resty.Client + host string + log *slog.Logger +} + +func NewClient(host, token string, log *slog.Logger) *Client { + setAuth := func(_ *resty.Client, req *resty.Request) error { + if token != "" { + req.SetHeader("Authorization", token) + } + + return nil + } + + jsonClient := resty.New(). + SetTimeout(apiCallTimeout). + SetRetryCount(maxRetryAttempts - 1). + SetRetryWaitTime(retryBaseDelay). + SetRetryMaxWaitTime(4 * retryBaseDelay). + AddRetryCondition(func(resp *resty.Response, err error) bool { + return err != nil || resp.StatusCode() >= 500 + }). + OnBeforeRequest(setAuth) + + streamClient := resty.New(). + OnBeforeRequest(setAuth) + + return &Client{ + json: jsonClient, + stream: streamClient, + host: host, + log: log, + } +} + +func (c *Client) CheckWalChainValidity(ctx context.Context) (*WalChainValidityResponse, error) { + var resp WalChainValidityResponse + + _, err := c.json.R(). + SetContext(ctx). + SetResult(&resp). + Get(c.buildURL(chainValidPath)) + if err != nil { + return nil, err + } + + return &resp, nil +} + +func (c *Client) GetNextFullBackupTime(ctx context.Context) (*NextFullBackupTimeResponse, error) { + var resp NextFullBackupTimeResponse + + _, err := c.json.R(). + SetContext(ctx). + SetResult(&resp). + Get(c.buildURL(nextBackupTimePath)) + if err != nil { + return nil, err + } + + return &resp, nil +} + +func (c *Client) ReportBackupError(ctx context.Context, errMsg string) error { + _, err := c.json.R(). + SetContext(ctx). + SetBody(reportErrorRequest{Error: errMsg}). + Post(c.buildURL(reportErrorPath)) + + return err +} + +func (c *Client) UploadBasebackup( + ctx context.Context, + startSegment string, + stopSegment string, + body io.Reader, +) error { + url := fmt.Sprintf("%s?fullBackupWalStartSegment=%s&fullBackupWalStopSegment=%s", + c.buildURL(uploadPath), startSegment, stopSegment, + ) + + resp, err := c.stream.R(). + SetContext(ctx). + SetBody(body). + SetHeader("Content-Type", "application/octet-stream"). + SetHeader("X-Upload-Type", "basebackup"). + SetDoNotParseResponse(true). + Post(url) + if err != nil { + return fmt.Errorf("upload request: %w", err) + } + defer func() { _ = resp.RawBody().Close() }() + + if resp.StatusCode() != http.StatusNoContent { + respBody, _ := io.ReadAll(resp.RawBody()) + + return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody)) + } + + return nil +} + +func (c *Client) UploadWalSegment( + ctx context.Context, + segmentName string, + body io.Reader, +) (*UploadWalSegmentResult, error) { + resp, err := c.stream.R(). + SetContext(ctx). + SetBody(body). + SetHeader("Content-Type", "application/octet-stream"). + SetHeader("X-Upload-Type", "wal"). + SetHeader("X-Wal-Segment-Name", segmentName). + SetDoNotParseResponse(true). + Post(c.buildURL(uploadPath)) + if err != nil { + return nil, fmt.Errorf("upload request: %w", err) + } + defer func() { _ = resp.RawBody().Close() }() + + switch resp.StatusCode() { + case http.StatusNoContent: + return &UploadWalSegmentResult{IsGapDetected: false}, nil + + case http.StatusConflict: + var errResp uploadErrorResponse + + if err := json.NewDecoder(resp.RawBody()).Decode(&errResp); err != nil { + return &UploadWalSegmentResult{IsGapDetected: true}, nil + } + + return &UploadWalSegmentResult{ + IsGapDetected: true, + ExpectedSegmentName: errResp.ExpectedSegmentName, + ReceivedSegmentName: errResp.ReceivedSegmentName, + }, nil + + default: + respBody, _ := io.ReadAll(resp.RawBody()) + + return nil, fmt.Errorf("upload failed with status %d: %s", resp.StatusCode(), string(respBody)) + } +} + +func (c *Client) FetchServerVersion(ctx context.Context) (string, error) { + var ver versionResponse + + _, err := c.json.R(). + SetContext(ctx). + SetResult(&ver). + Get(c.buildURL(versionPath)) + if err != nil { + return "", err + } + + return ver.Version, nil +} + +func (c *Client) DownloadAgentBinary(ctx context.Context, arch, destPath string) error { + resp, err := c.stream.R(). + SetContext(ctx). + SetQueryParam("arch", arch). + SetDoNotParseResponse(true). + Get(c.buildURL(agentBinaryPath)) + if err != nil { + return err + } + defer func() { _ = resp.RawBody().Close() }() + + if resp.StatusCode() != http.StatusOK { + return fmt.Errorf("server returned %d for agent download", resp.StatusCode()) + } + + f, err := os.Create(destPath) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + _, err = io.Copy(f, resp.RawBody()) + + return err +} + +func (c *Client) buildURL(path string) string { + return c.host + path +} diff --git a/agent/internal/features/api/dto.go b/agent/internal/features/api/dto.go new file mode 100644 index 0000000..2a729cd --- /dev/null +++ b/agent/internal/features/api/dto.go @@ -0,0 +1,33 @@ +package api + +import "time" + +type WalChainValidityResponse struct { + IsValid bool `json:"isValid"` + Error string `json:"error,omitempty"` + LastContiguousSegment string `json:"lastContiguousSegment,omitempty"` +} + +type NextFullBackupTimeResponse struct { + NextFullBackupTime *time.Time `json:"nextFullBackupTime"` +} + +type UploadWalSegmentResult struct { + IsGapDetected bool + ExpectedSegmentName string + ReceivedSegmentName string +} + +type reportErrorRequest struct { + Error string `json:"error"` +} + +type versionResponse struct { + Version string `json:"version"` +} + +type uploadErrorResponse struct { + Error string `json:"error"` + ExpectedSegmentName string `json:"expectedSegmentName"` + ReceivedSegmentName string `json:"receivedSegmentName"` +} diff --git a/agent/internal/features/full_backup/backuper.go b/agent/internal/features/full_backup/backuper.go new file mode 100644 index 0000000..8bd78ea --- /dev/null +++ b/agent/internal/features/full_backup/backuper.go @@ -0,0 +1,278 @@ +package full_backup + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "path/filepath" + "sync/atomic" + "time" + + "github.com/klauspost/compress/zstd" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" +) + +const ( + checkInterval = 30 * time.Second + retryDelay = 1 * time.Minute + uploadTimeout = 30 * time.Minute +) + +var retryDelayOverride *time.Duration + +type CmdBuilder func(ctx context.Context) *exec.Cmd + +// FullBackuper runs pg_basebackup when the WAL chain is broken or a scheduled backup is due. +// +// Every 30 seconds it checks two conditions via the Databasus API: +// 1. WAL chain validity — if broken or no full backup exists, triggers an immediate basebackup. +// 2. Scheduled backup time — if the next full backup time has passed, triggers a basebackup. +// +// Only one basebackup runs at a time (guarded by atomic bool). +// On failure the error is reported to the server and the backup retries after 1 minute, indefinitely. +// WAL segment uploads (handled by wal.Streamer) continue independently and are not paused. +// +// pg_basebackup runs as "pg_basebackup -Ft -D - -X none --verbose". Stdout (tar) is zstd-compressed +// and uploaded to the server. Stderr is parsed for WAL start/stop segment names (LSN → segment arithmetic). +type FullBackuper struct { + cfg *config.Config + apiClient *api.Client + log *slog.Logger + isRunning atomic.Bool + cmdBuilder CmdBuilder +} + +func NewFullBackuper(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *FullBackuper { + backuper := &FullBackuper{ + cfg: cfg, + apiClient: apiClient, + log: log, + } + + backuper.cmdBuilder = backuper.defaultCmdBuilder + + return backuper +} + +func (backuper *FullBackuper) Run(ctx context.Context) { + backuper.log.Info("Full backuper started") + + backuper.checkAndRunIfNeeded(ctx) + + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + backuper.log.Info("Full backuper stopping") + return + case <-ticker.C: + backuper.checkAndRunIfNeeded(ctx) + } + } +} + +func (backuper *FullBackuper) checkAndRunIfNeeded(ctx context.Context) { + if backuper.isRunning.Load() { + backuper.log.Debug("Skipping check: basebackup already in progress") + return + } + + chainResp, err := backuper.apiClient.CheckWalChainValidity(ctx) + if err != nil { + backuper.log.Error("Failed to check WAL chain validity", "error", err) + return + } + + if !chainResp.IsValid { + backuper.log.Info("WAL chain is invalid, triggering basebackup", + "error", chainResp.Error, + "lastContiguousSegment", chainResp.LastContiguousSegment, + ) + + backuper.runBasebackupWithRetry(ctx) + + return + } + + nextTimeResp, err := backuper.apiClient.GetNextFullBackupTime(ctx) + if err != nil { + backuper.log.Error("Failed to check next full backup time", "error", err) + return + } + + if nextTimeResp.NextFullBackupTime == nil || !nextTimeResp.NextFullBackupTime.After(time.Now().UTC()) { + backuper.log.Info("Scheduled full backup is due, triggering basebackup") + backuper.runBasebackupWithRetry(ctx) + + return + } + + backuper.log.Debug("No basebackup needed", + "nextFullBackupTime", nextTimeResp.NextFullBackupTime, + ) +} + +func (backuper *FullBackuper) runBasebackupWithRetry(ctx context.Context) { + if !backuper.isRunning.CompareAndSwap(false, true) { + backuper.log.Debug("Skipping basebackup: already running") + return + } + defer backuper.isRunning.Store(false) + + for { + if ctx.Err() != nil { + return + } + + backuper.log.Info("Starting pg_basebackup") + + err := backuper.executeAndUploadBasebackup(ctx) + if err == nil { + backuper.log.Info("Basebackup completed successfully") + return + } + + backuper.log.Error("Basebackup failed", "error", err) + backuper.reportError(ctx, err.Error()) + + delay := retryDelay + if retryDelayOverride != nil { + delay = *retryDelayOverride + } + + backuper.log.Info("Retrying basebackup after delay", "delay", delay) + + select { + case <-ctx.Done(): + return + case <-time.After(delay): + } + } +} + +func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) error { + cmd := backuper.cmdBuilder(ctx) + + var stderrBuf bytes.Buffer + cmd.Stderr = &stderrBuf + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("create stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start pg_basebackup: %w", err) + } + + var compressedBuf bytes.Buffer + compressErr := backuper.compressToBuffer(&compressedBuf, stdoutPipe) + + cmdErr := cmd.Wait() + if cmdErr != nil { + return fmt.Errorf("pg_basebackup exited with error: %w (stderr: %s)", cmdErr, stderrBuf.String()) + } + + if compressErr != nil { + return fmt.Errorf("compress basebackup: %w", compressErr) + } + + stderrStr := stderrBuf.String() + backuper.log.Debug("pg_basebackup stderr", "stderr", stderrStr) + + startSegment, stopSegment, err := ParseBasebackupStderr(stderrStr) + if err != nil { + return fmt.Errorf("parse pg_basebackup stderr: %w", err) + } + + backuper.log.Info("Basebackup WAL segments parsed", + "startSegment", startSegment, + "stopSegment", stopSegment, + "compressedSize", compressedBuf.Len(), + ) + + uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) + defer cancel() + + if err := backuper.apiClient.UploadBasebackup(uploadCtx, startSegment, stopSegment, &compressedBuf); err != nil { + return fmt.Errorf("upload basebackup: %w", err) + } + + return nil +} + +func (backuper *FullBackuper) compressToBuffer(dst *bytes.Buffer, reader io.Reader) error { + encoder, err := zstd.NewWriter(dst, + zstd.WithEncoderLevel(zstd.SpeedDefault), + zstd.WithEncoderCRC(true), + ) + if err != nil { + return fmt.Errorf("create zstd encoder: %w", err) + } + + if _, err := io.Copy(encoder, reader); err != nil { + _ = encoder.Close() + return fmt.Errorf("compress: %w", err) + } + + if err := encoder.Close(); err != nil { + return fmt.Errorf("close encoder: %w", err) + } + + return nil +} + +func (backuper *FullBackuper) reportError(ctx context.Context, errMsg string) { + if err := backuper.apiClient.ReportBackupError(ctx, errMsg); err != nil { + backuper.log.Error("Failed to report error to server", "error", err) + } +} + +func (backuper *FullBackuper) defaultCmdBuilder(ctx context.Context) *exec.Cmd { + switch backuper.cfg.PgType { + case "docker": + return backuper.buildDockerCmd(ctx) + default: + return backuper.buildHostCmd(ctx) + } +} + +func (backuper *FullBackuper) buildHostCmd(ctx context.Context) *exec.Cmd { + binary := "pg_basebackup" + if backuper.cfg.PgHostBinDir != "" { + binary = filepath.Join(backuper.cfg.PgHostBinDir, "pg_basebackup") + } + + cmd := exec.CommandContext(ctx, binary, + "-Ft", "-D", "-", "-X", "none", "--verbose", + "-h", backuper.cfg.PgHost, + "-p", fmt.Sprintf("%d", backuper.cfg.PgPort), + "-U", backuper.cfg.PgUser, + ) + + cmd.Env = append(os.Environ(), "PGPASSWORD="+backuper.cfg.PgPassword) + + return cmd +} + +func (backuper *FullBackuper) buildDockerCmd(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, "docker", "exec", + "-e", "PGPASSWORD="+backuper.cfg.PgPassword, + "-i", backuper.cfg.PgDockerContainerName, + "pg_basebackup", + "-Ft", "-D", "-", "-X", "none", "--verbose", + "-h", backuper.cfg.PgHost, + "-p", fmt.Sprintf("%d", backuper.cfg.PgPort), + "-U", backuper.cfg.PgUser, + ) + + return cmd +} diff --git a/agent/internal/features/full_backup/backuper_test.go b/agent/internal/features/full_backup/backuper_test.go new file mode 100644 index 0000000..1c8350f --- /dev/null +++ b/agent/internal/features/full_backup/backuper_test.go @@ -0,0 +1,604 @@ +package full_backup + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "databasus-agent/internal/config" + "databasus-agent/internal/features/api" + "databasus-agent/internal/logger" +) + +const ( + testChainValidPath = "/api/v1/backups/postgres/wal/is-wal-chain-valid-since-last-full-backup" + testNextBackupTimePath = "/api/v1/backups/postgres/wal/next-full-backup-time" + testUploadPath = "/api/v1/backups/postgres/wal/upload" + testReportErrorPath = "/api/v1/backups/postgres/wal/error" +) + +func Test_RunFullBackup_WhenChainBroken_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var uploadReceived bool + var uploadHeaders http.Header + var uploadQuery string + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "wal_chain_broken", + LastContiguousSegment: "000000010000000100000011", + }) + case testUploadPath: + mu.Lock() + uploadReceived = true + uploadHeaders = r.Header.Clone() + uploadQuery = r.URL.RawQuery + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "test-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, uploadReceived) + assert.Equal(t, "basebackup", uploadHeaders.Get("X-Upload-Type")) + assert.Equal(t, "application/octet-stream", uploadHeaders.Get("Content-Type")) + assert.Equal(t, "test-token", uploadHeaders.Get("Authorization")) + assert.Contains(t, uploadQuery, "fullBackupWalStartSegment=000000010000000000000002") + assert.Contains(t, uploadQuery, "fullBackupWalStopSegment=000000010000000000000002") +} + +func Test_RunFullBackup_WhenScheduledBackupDue_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var uploadReceived bool + + pastTime := time.Now().UTC().Add(-1 * time.Hour) + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &pastTime}) + case testUploadPath: + mu.Lock() + uploadReceived = true + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "scheduled-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, uploadReceived) +} + +func Test_RunFullBackup_WhenNoFullBackupExists_ImmediateBasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var uploadReceived bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + mu.Lock() + uploadReceived = true + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "first-backup-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, uploadReceived) +} + +func Test_RunFullBackup_WhenUploadFails_RetriesAfterDelay(t *testing.T) { + var mu sync.Mutex + var uploadAttempts int + var errorReported bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + _, _ = io.ReadAll(r.Body) + + mu.Lock() + uploadAttempts++ + attempt := uploadAttempts + mu.Unlock() + + if attempt == 1 { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"storage unavailable"}`)) + return + } + + w.WriteHeader(http.StatusNoContent) + case testReportErrorPath: + mu.Lock() + errorReported = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "retry-backup-data", validStderr()) + + origRetryDelay := retryDelay + setRetryDelay(100 * time.Millisecond) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadAttempts >= 2 + }, 10*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.GreaterOrEqual(t, uploadAttempts, 2) + assert.True(t, errorReported) +} + +func Test_RunFullBackup_WhenAlreadyRunning_SkipsExecution(t *testing.T) { + var mu sync.Mutex + var uploadCount int + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + _, _ = io.ReadAll(r.Body) + + mu.Lock() + uploadCount++ + mu.Unlock() + + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + fb.isRunning.Store(true) + + fb.checkAndRunIfNeeded(context.Background()) + + mu.Lock() + count := uploadCount + mu.Unlock() + + assert.Equal(t, 0, count, "should not trigger backup when already running") +} + +func Test_RunFullBackup_WhenContextCancelled_StopsCleanly(t *testing.T) { + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusInternalServerError) + case testReportErrorPath: + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + origRetryDelay := retryDelay + setRetryDelay(5 * time.Second) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + defer cancel() + + done := make(chan struct{}) + go func() { + fb.Run(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("Run should have stopped after context cancellation") + } +} + +func Test_RunFullBackup_WhenChainValidAndNotScheduled_NoBasebackupTriggered(t *testing.T) { + var uploadReceived atomic.Bool + + futureTime := time.Now().UTC().Add(24 * time.Hour) + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: &futureTime}) + case testUploadPath: + uploadReceived.Store(true) + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go fb.Run(ctx) + time.Sleep(500 * time.Millisecond) + cancel() + + assert.False(t, uploadReceived.Load(), "should not trigger backup when chain valid and not scheduled") +} + +func Test_RunFullBackup_WhenStderrParsingFails_ReportsErrorAndRetries(t *testing.T) { + var mu sync.Mutex + var errorReported bool + var uploadAttempts int + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + mu.Lock() + uploadAttempts++ + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + case testReportErrorPath: + mu.Lock() + errorReported = true + mu.Unlock() + + w.WriteHeader(http.StatusOK) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "data", "pg_basebackup: unexpected output with no LSN info") + + origRetryDelay := retryDelay + setRetryDelay(100 * time.Millisecond) + defer setRetryDelay(origRetryDelay) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return errorReported + }, 2*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, errorReported) + assert.Equal(t, 0, uploadAttempts, "should not upload when stderr parsing fails") +} + +func Test_RunFullBackup_WhenNextBackupTimeNull_BasebackupTriggered(t *testing.T) { + var mu sync.Mutex + var uploadReceived bool + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{IsValid: true}) + case testNextBackupTimePath: + writeJSON(w, api.NextFullBackupTimeResponse{NextFullBackupTime: nil}) + case testUploadPath: + mu.Lock() + uploadReceived = true + mu.Unlock() + + _, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, "first-run-data", validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return uploadReceived + }, 5*time.Second) + cancel() + + mu.Lock() + defer mu.Unlock() + + assert.True(t, uploadReceived) +} + +func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) { + var mu sync.Mutex + var receivedBody []byte + + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testChainValidPath: + writeJSON(w, api.WalChainValidityResponse{ + IsValid: false, + Error: "no_full_backup", + }) + case testUploadPath: + body, _ := io.ReadAll(r.Body) + + mu.Lock() + receivedBody = body + mu.Unlock() + + w.WriteHeader(http.StatusNoContent) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + originalContent := "test-backup-content-for-compression-check" + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = mockCmdBuilder(t, originalContent, validStderr()) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + go fb.Run(ctx) + waitForCondition(t, func() bool { + mu.Lock() + defer mu.Unlock() + return len(receivedBody) > 0 + }, 5*time.Second) + cancel() + + mu.Lock() + body := receivedBody + mu.Unlock() + + decoder, err := zstd.NewReader(nil) + require.NoError(t, err) + defer decoder.Close() + + decompressed, err := decoder.DecodeAll(body, nil) + require.NoError(t, err) + assert.Equal(t, originalContent, string(decompressed)) +} + +func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + + return server +} + +func newTestFullBackuper(serverURL string) *FullBackuper { + cfg := &config.Config{ + DatabasusHost: serverURL, + DbID: "test-db-id", + Token: "test-token", + PgHost: "localhost", + PgPort: 5432, + PgUser: "postgres", + PgPassword: "password", + PgType: "host", + } + + apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger()) + + return NewFullBackuper(cfg, apiClient, logger.GetLogger()) +} + +func mockCmdBuilder(t *testing.T, stdoutContent, stderrContent string) CmdBuilder { + t.Helper() + + return func(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, os.Args[0], + "-test.run=TestHelperProcess", + "--", + stdoutContent, + stderrContent, + ) + + cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS=1") + + return cmd + } +} + +func TestHelperProcess(t *testing.T) { + if os.Getenv("GO_TEST_HELPER_PROCESS") != "1" { + return + } + + args := os.Args + for i, arg := range args { + if arg == "--" { + args = args[i+1:] + break + } + } + + if len(args) >= 1 { + _, _ = fmt.Fprint(os.Stdout, args[0]) + } + + if len(args) >= 2 { + _, _ = fmt.Fprint(os.Stderr, args[1]) + } + + os.Exit(0) +} + +func validStderr() string { + return `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1 +pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: base backup completed` +} + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + + if err := json.NewEncoder(w).Encode(v); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } +} + +func waitForCondition(t *testing.T, condition func() bool, timeout time.Duration) { + t.Helper() + + deadline := time.Now().Add(timeout) + + for time.Now().Before(deadline) { + if condition() { + return + } + + time.Sleep(50 * time.Millisecond) + } + + t.Fatalf("condition not met within %v", timeout) +} + +func setRetryDelay(d time.Duration) { + retryDelayOverride = &d +} + +func init() { + retryDelayOverride = nil +} diff --git a/agent/internal/features/full_backup/stderr_parser.go b/agent/internal/features/full_backup/stderr_parser.go new file mode 100644 index 0000000..7c6069c --- /dev/null +++ b/agent/internal/features/full_backup/stderr_parser.go @@ -0,0 +1,75 @@ +package full_backup + +import ( + "fmt" + "regexp" + "strconv" + "strings" +) + +const defaultWalSegmentSize uint32 = 16 * 1024 * 1024 // 16 MB + +var ( + startLSNRegex = regexp.MustCompile(`checkpoint redo point at ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`) + stopLSNRegex = regexp.MustCompile(`write-ahead log end point: ([0-9A-Fa-f]+/[0-9A-Fa-f]+)`) +) + +func ParseBasebackupStderr(stderr string) (startSegment, stopSegment string, err error) { + startMatch := startLSNRegex.FindStringSubmatch(stderr) + if len(startMatch) < 2 { + return "", "", fmt.Errorf("failed to parse start WAL location from pg_basebackup stderr") + } + + stopMatch := stopLSNRegex.FindStringSubmatch(stderr) + if len(stopMatch) < 2 { + return "", "", fmt.Errorf("failed to parse stop WAL location from pg_basebackup stderr") + } + + startSegment, err = LSNToSegmentName(startMatch[1], 1, defaultWalSegmentSize) + if err != nil { + return "", "", fmt.Errorf("failed to convert start LSN to segment name: %w", err) + } + + stopSegment, err = LSNToSegmentName(stopMatch[1], 1, defaultWalSegmentSize) + if err != nil { + return "", "", fmt.Errorf("failed to convert stop LSN to segment name: %w", err) + } + + return startSegment, stopSegment, nil +} + +func LSNToSegmentName(lsn string, timelineID, walSegmentSize uint32) (string, error) { + high, low, err := parseLSN(lsn) + if err != nil { + return "", err + } + + segmentsPerXLogID := uint32(0x100000000 / uint64(walSegmentSize)) + logID := high + segmentOffset := low / walSegmentSize + + if segmentOffset >= segmentsPerXLogID { + return "", fmt.Errorf("segment offset %d exceeds segments per XLogId %d", segmentOffset, segmentsPerXLogID) + } + + return fmt.Sprintf("%08X%08X%08X", timelineID, logID, segmentOffset), nil +} + +func parseLSN(lsn string) (high, low uint32, err error) { + parts := strings.SplitN(lsn, "/", 2) + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid LSN format: %q (expected X/Y)", lsn) + } + + highVal, err := strconv.ParseUint(parts[0], 16, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid LSN high part %q: %w", parts[0], err) + } + + lowVal, err := strconv.ParseUint(parts[1], 16, 32) + if err != nil { + return 0, 0, fmt.Errorf("invalid LSN low part %q: %w", parts[1], err) + } + + return uint32(highVal), uint32(lowVal), nil +} diff --git a/agent/internal/features/full_backup/stderr_parser_test.go b/agent/internal/features/full_backup/stderr_parser_test.go new file mode 100644 index 0000000..cd833fc --- /dev/null +++ b/agent/internal/features/full_backup/stderr_parser_test.go @@ -0,0 +1,162 @@ +package full_backup + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseBasebackupStderr_WithPG17Output_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 0/2000028, on timeline 1 +pg_basebackup: starting background WAL receiver +pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: waiting for background process to finish streaming ... +pg_basebackup: syncing data to disk ... +pg_basebackup: renaming backup_manifest.tmp to backup_manifest +pg_basebackup: base backup completed` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "000000010000000000000002", startSeg) + assert.Equal(t, "000000010000000000000002", stopSeg) +} + +func Test_ParseBasebackupStderr_WithPG15Output_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: initiating base backup, waiting for checkpoint to complete +pg_basebackup: checkpoint completed +pg_basebackup: write-ahead log start point: 1/AB000028, on timeline 1 +pg_basebackup: checkpoint redo point at 1/AB000028 +pg_basebackup: write-ahead log end point: 1/AC000000 +pg_basebackup: base backup completed` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "0000000100000001000000AB", startSeg) + assert.Equal(t, "0000000100000001000000AC", stopSeg) +} + +func Test_ParseBasebackupStderr_WithHighLogID_ExtractsCorrectSegments(t *testing.T) { + stderr := `pg_basebackup: checkpoint redo point at A/FF000028 +pg_basebackup: write-ahead log end point: B/1000000` + + startSeg, stopSeg, err := ParseBasebackupStderr(stderr) + + require.NoError(t, err) + assert.Equal(t, "000000010000000A000000FF", startSeg) + assert.Equal(t, "000000010000000B00000001", stopSeg) +} + +func Test_ParseBasebackupStderr_WhenStartLSNMissing_ReturnsError(t *testing.T) { + stderr := `pg_basebackup: write-ahead log end point: 0/2000100 +pg_basebackup: base backup completed` + + _, _, err := ParseBasebackupStderr(stderr) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse start WAL location") +} + +func Test_ParseBasebackupStderr_WhenStopLSNMissing_ReturnsError(t *testing.T) { + stderr := `pg_basebackup: checkpoint redo point at 0/2000028 +pg_basebackup: base backup completed` + + _, _, err := ParseBasebackupStderr(stderr) + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse stop WAL location") +} + +func Test_ParseBasebackupStderr_WhenEmptyStderr_ReturnsError(t *testing.T) { + _, _, err := ParseBasebackupStderr("") + + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse start WAL location") +} + +func Test_LSNToSegmentName_WithBoundaryValues_ConvertsCorrectly(t *testing.T) { + tests := []struct { + name string + lsn string + timeline uint32 + segSize uint32 + expected string + }{ + { + name: "first segment", + lsn: "0/1000000", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000001", + }, + { + name: "segment at boundary FF", + lsn: "0/FF000000", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "0000000100000000000000FF", + }, + { + name: "segment in second log file", + lsn: "1/0", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000100000000", + }, + { + name: "segment with offset within 16MB", + lsn: "0/200ABCD", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000002", + }, + { + name: "zero LSN", + lsn: "0/0", + timeline: 1, + segSize: 16 * 1024 * 1024, + expected: "000000010000000000000000", + }, + { + name: "high timeline ID", + lsn: "0/1000000", + timeline: 2, + segSize: 16 * 1024 * 1024, + expected: "000000020000000000000001", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := LSNToSegmentName(tt.lsn, tt.timeline, tt.segSize) + + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + }) + } +} + +func Test_LSNToSegmentName_WithInvalidLSN_ReturnsError(t *testing.T) { + tests := []struct { + name string + lsn string + }{ + {name: "no slash", lsn: "012345"}, + {name: "empty string", lsn: ""}, + {name: "invalid hex high", lsn: "GG/0"}, + {name: "invalid hex low", lsn: "0/ZZ"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := LSNToSegmentName(tt.lsn, 1, 16*1024*1024) + + require.Error(t, err) + }) + } +} diff --git a/agent/internal/features/start/daemon.go b/agent/internal/features/start/daemon.go new file mode 100644 index 0000000..88e865b --- /dev/null +++ b/agent/internal/features/start/daemon.go @@ -0,0 +1,121 @@ +//go:build !windows + +package start + +import ( + "errors" + "fmt" + "log/slog" + "os" + "os/exec" + "syscall" + "time" +) + +const ( + logFileName = "databasus.log" + stopTimeout = 30 * time.Second + stopPollInterval = 500 * time.Millisecond + daemonStartupDelay = 500 * time.Millisecond +) + +func Stop(log *slog.Logger) error { + pid, err := ReadLockFilePID() + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return errors.New("agent is not running (no lock file found)") + } + + return fmt.Errorf("failed to read lock file: %w", err) + } + + if !isProcessAlive(pid) { + _ = os.Remove(lockFileName) + return fmt.Errorf("agent is not running (stale lock file removed, PID %d)", pid) + } + + log.Info("Sending SIGTERM to agent", "pid", pid) + + if err := syscall.Kill(pid, syscall.SIGTERM); err != nil { + return fmt.Errorf("failed to send SIGTERM to PID %d: %w", pid, err) + } + + deadline := time.Now().Add(stopTimeout) + for time.Now().Before(deadline) { + if !isProcessAlive(pid) { + log.Info("Agent stopped", "pid", pid) + return nil + } + + time.Sleep(stopPollInterval) + } + + return fmt.Errorf("agent (PID %d) did not stop within %s — process may be stuck", pid, stopTimeout) +} + +func Status(log *slog.Logger) error { + pid, err := ReadLockFilePID() + if err != nil { + if errors.Is(err, os.ErrNotExist) { + fmt.Println("Agent is not running") + return nil + } + + return fmt.Errorf("failed to read lock file: %w", err) + } + + if isProcessAlive(pid) { + fmt.Printf("Agent is running (PID %d)\n", pid) + } else { + fmt.Println("Agent is not running (stale lock file)") + _ = os.Remove(lockFileName) + } + + return nil +} + +func spawnDaemon(log *slog.Logger) (int, error) { + execPath, err := os.Executable() + if err != nil { + return 0, fmt.Errorf("failed to resolve executable path: %w", err) + } + + args := []string{"_run"} + + logFile, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return 0, fmt.Errorf("failed to open log file %s: %w", logFileName, err) + } + + cwd, err := os.Getwd() + if err != nil { + _ = logFile.Close() + return 0, fmt.Errorf("failed to get working directory: %w", err) + } + + cmd := exec.Command(execPath, args...) + cmd.Dir = cwd + cmd.Stdout = logFile + cmd.Stderr = logFile + cmd.SysProcAttr = &syscall.SysProcAttr{Setsid: true} + + if err := cmd.Start(); err != nil { + _ = logFile.Close() + return 0, fmt.Errorf("failed to start daemon process: %w", err) + } + + pid := cmd.Process.Pid + + // Detach — we don't wait for the child + _ = logFile.Close() + + time.Sleep(daemonStartupDelay) + + if !isProcessAlive(pid) { + return 0, fmt.Errorf("daemon process (PID %d) exited immediately — check %s for details", pid, logFileName) + } + + log.Info("Daemon spawned", "pid", pid, "log", logFileName) + + return pid, nil +} diff --git a/agent/internal/features/start/daemon_windows.go b/agent/internal/features/start/daemon_windows.go new file mode 100644 index 0000000..8743554 --- /dev/null +++ b/agent/internal/features/start/daemon_windows.go @@ -0,0 +1,20 @@ +//go:build windows + +package start + +import ( + "errors" + "log/slog" +) + +func Stop(log *slog.Logger) error { + return errors.New("stop is not supported on Windows — use Ctrl+C in the terminal where the agent is running") +} + +func Status(log *slog.Logger) error { + return errors.New("status is not supported on Windows — check the terminal where the agent is running") +} + +func spawnDaemon(_ *slog.Logger) (int, error) { + return 0, errors.New("daemon mode is not supported on Windows") +} diff --git a/agent/internal/features/start/lock.go b/agent/internal/features/start/lock.go index 0ce5180..58a0d10 100644 --- a/agent/internal/features/start/lock.go +++ b/agent/internal/features/start/lock.go @@ -54,6 +54,16 @@ func ReleaseLock(f *os.File) { _ = os.Remove(lockFileName) } +func ReadLockFilePID() (int, error) { + f, err := os.Open(lockFileName) + if err != nil { + return 0, err + } + defer f.Close() + + return readLockPID(f) +} + func writePID(f *os.File) error { if err := f.Truncate(0); err != nil { return fmt.Errorf("failed to truncate lock file: %w", err) diff --git a/agent/internal/features/start/start.go b/agent/internal/features/start/start.go index 1f32275..a418829 100644 --- a/agent/internal/features/start/start.go +++ b/agent/internal/features/start/start.go @@ -9,6 +9,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "runtime" "strings" "syscall" "time" @@ -16,6 +17,8 @@ import ( "github.com/jackc/pgx/v5" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" + full_backup "databasus-agent/internal/features/full_backup" "databasus-agent/internal/features/wal" ) @@ -24,13 +27,7 @@ const ( dbVerifyTimeout = 10 * time.Second ) -func Run(cfg *config.Config, log *slog.Logger) error { - lockFile, err := AcquireLock(log) - if err != nil { - return err - } - defer ReleaseLock(lockFile) - +func Start(cfg *config.Config, log *slog.Logger) error { if err := validateConfig(cfg); err != nil { return err } @@ -43,10 +40,36 @@ func Run(cfg *config.Config, log *slog.Logger) error { return err } + if runtime.GOOS == "windows" { + return RunDaemon(cfg, log) + } + + pid, err := spawnDaemon(log) + if err != nil { + return err + } + + fmt.Printf("Agent started in background (PID %d)\n", pid) + + return nil +} + +func RunDaemon(cfg *config.Config, log *slog.Logger) error { + lockFile, err := AcquireLock(log) + if err != nil { + return err + } + defer ReleaseLock(lockFile) + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer cancel() - streamer := wal.NewStreamer(cfg, log) + apiClient := api.NewClient(cfg.DatabasusHost, cfg.Token, log) + + fullBackuper := full_backup.NewFullBackuper(cfg, apiClient, log) + go fullBackuper.Run(ctx) + + streamer := wal.NewStreamer(cfg, apiClient, log) streamer.Run(ctx) log.Info("Agent stopped") diff --git a/agent/internal/features/upgrade/dto.go b/agent/internal/features/upgrade/dto.go deleted file mode 100644 index 11f251b..0000000 --- a/agent/internal/features/upgrade/dto.go +++ /dev/null @@ -1,5 +0,0 @@ -package upgrade - -type versionResponse struct { - Version string `json:"version"` -} diff --git a/agent/internal/features/upgrade/upgrader.go b/agent/internal/features/upgrade/upgrader.go index a75ae2d..cf01ce9 100644 --- a/agent/internal/features/upgrade/upgrader.go +++ b/agent/internal/features/upgrade/upgrader.go @@ -2,30 +2,33 @@ package upgrade import ( "context" - "encoding/json" "fmt" - "io" "log/slog" - "net/http" "os" "os/exec" "runtime" "strings" "syscall" - "time" + + "databasus-agent/internal/features/api" ) -func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog.Logger) error { +// CheckAndUpdate ensures the agent binary matches the server's expected version. +// It fetches the server version, downloads the new binary if different, verifies it, +// replaces the current executable, and re-execs the process with the same arguments. +// Skipped in development mode. Runs once on startup before the main agent loop. +func CheckAndUpdate(apiClient *api.Client, currentVersion string, isDev bool, log *slog.Logger) error { if isDev { log.Info("Skipping update check (development mode)") return nil } - serverVersion, err := fetchServerVersion(databasusHost, log) + serverVersion, err := apiClient.FetchServerVersion(context.Background()) if err != nil { + log.Warn("Could not reach server for update check, continuing", "error", err) + return fmt.Errorf( - "unable to check version, please verify Databasus server is available at %s: %w", - databasusHost, + "unable to check version, please verify Databasus server is available: %w", err, ) } @@ -48,7 +51,7 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog. _ = os.Remove(tempPath) }() - if err := downloadBinary(databasusHost, tempPath); err != nil { + if err := apiClient.DownloadAgentBinary(context.Background(), runtime.GOARCH, tempPath); err != nil { return fmt.Errorf("failed to download update: %w", err) } @@ -69,74 +72,6 @@ func CheckAndUpdate(databasusHost, currentVersion string, isDev bool, log *slog. return syscall.Exec(selfPath, os.Args, os.Environ()) } -func fetchServerVersion(host string, log *slog.Logger) (string, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - client := &http.Client{Timeout: 10 * time.Second} - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, host+"/api/v1/system/version", nil) - if err != nil { - return "", err - } - - resp, err := client.Do(req) - if err != nil { - log.Warn("Could not reach server for update check, continuing", "error", err) - return "", err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - log.Warn( - "Server returned non-OK status for version check, continuing", - "status", - resp.StatusCode, - ) - return "", fmt.Errorf("status %d", resp.StatusCode) - } - - var ver versionResponse - if err := json.NewDecoder(resp.Body).Decode(&ver); err != nil { - log.Warn("Failed to parse server version response, continuing", "error", err) - return "", err - } - - return ver.Version, nil -} - -func downloadBinary(host, destPath string) error { - url := fmt.Sprintf("%s/api/v1/system/agent?arch=%s", host, runtime.GOARCH) - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) - defer cancel() - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return err - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer func() { _ = resp.Body.Close() }() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("server returned %d for agent download", resp.StatusCode) - } - - f, err := os.Create(destPath) - if err != nil { - return err - } - defer func() { _ = f.Close() }() - - _, err = io.Copy(f, resp.Body) - - return err -} - func verifyBinary(binaryPath, expectedVersion string) error { cmd := exec.CommandContext(context.Background(), binaryPath, "version") diff --git a/agent/internal/features/wal/streamer.go b/agent/internal/features/wal/streamer.go index 105d760..222b571 100644 --- a/agent/internal/features/wal/streamer.go +++ b/agent/internal/features/wal/streamer.go @@ -2,11 +2,9 @@ package wal import ( "context" - "encoding/json" "fmt" "io" "log/slog" - "net/http" "os" "path/filepath" "regexp" @@ -17,33 +15,27 @@ import ( "github.com/klauspost/compress/zstd" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" ) const ( pollInterval = 2 * time.Second uploadTimeout = 5 * time.Minute - uploadPath = "/api/v1/backups/postgres/wal/upload" ) var segmentNameRegex = regexp.MustCompile(`^[0-9A-Fa-f]{24}$`) type Streamer struct { - cfg *config.Config - httpClient *http.Client - log *slog.Logger + cfg *config.Config + apiClient *api.Client + log *slog.Logger } -type uploadErrorResponse struct { - Error string `json:"error"` - ExpectedSegmentName string `json:"expectedSegmentName"` - ReceivedSegmentName string `json:"receivedSegmentName"` -} - -func NewStreamer(cfg *config.Config, log *slog.Logger) *Streamer { +func NewStreamer(cfg *config.Config, apiClient *api.Client, log *slog.Logger) *Streamer { return &Streamer{ - cfg: cfg, - httpClient: &http.Client{}, - log: log, + cfg: cfg, + apiClient: apiClient, + log: log, } } @@ -129,58 +121,33 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) defer cancel() - req, err := http.NewRequestWithContext(uploadCtx, http.MethodPost, s.buildUploadURL(), pr) + result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr) if err != nil { - _ = pr.Close() - return fmt.Errorf("create request: %w", err) + return err } - req.Header.Set("Authorization", s.cfg.Token) - req.Header.Set("Content-Type", "application/octet-stream") - req.Header.Set("X-Upload-Type", "wal") - req.Header.Set("X-Wal-Segment-Name", segmentName) - - resp, err := s.httpClient.Do(req) - if err != nil { - return fmt.Errorf("upload request: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - switch resp.StatusCode { - case http.StatusNoContent: - s.log.Debug("WAL segment uploaded", "segment", segmentName) - - if *s.cfg.IsDeleteWalAfterUpload { - if err := os.Remove(filePath); err != nil { - s.log.Warn("Failed to delete uploaded WAL segment", - "segment", segmentName, - "error", err, - ) - } - } - - return nil - - case http.StatusConflict: - var errResp uploadErrorResponse - - if err := json.NewDecoder(resp.Body).Decode(&errResp); err == nil { - s.log.Warn("WAL chain gap detected", - "segment", segmentName, - "expected", errResp.ExpectedSegmentName, - "received", errResp.ReceivedSegmentName, - ) - } else { - s.log.Warn("WAL chain gap detected", "segment", segmentName) - } + if result.IsGapDetected { + s.log.Warn("WAL chain gap detected", + "segment", segmentName, + "expected", result.ExpectedSegmentName, + "received", result.ReceivedSegmentName, + ) return fmt.Errorf("gap detected for segment %s", segmentName) - - default: - body, _ := io.ReadAll(resp.Body) - - return fmt.Errorf("upload failed with status %d: %s", resp.StatusCode, string(body)) } + + s.log.Debug("WAL segment uploaded", "segment", segmentName) + + if *s.cfg.IsDeleteWalAfterUpload { + if err := os.Remove(filePath); err != nil { + s.log.Warn("Failed to delete uploaded WAL segment", + "segment", segmentName, + "error", err, + ) + } + } + + return nil } func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) { @@ -213,7 +180,3 @@ func (s *Streamer) compressAndStream(pw *io.PipeWriter, filePath string) { _ = pw.Close() } - -func (s *Streamer) buildUploadURL() string { - return s.cfg.DatabasusHost + uploadPath -} diff --git a/agent/internal/features/wal/streamer_test.go b/agent/internal/features/wal/streamer_test.go index 073a633..aa5e2af 100644 --- a/agent/internal/features/wal/streamer_test.go +++ b/agent/internal/features/wal/streamer_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "databasus-agent/internal/config" + "databasus-agent/internal/features/api" "databasus-agent/internal/logger" ) @@ -39,8 +40,7 @@ func Test_UploadSegment_SingleSegment_ServerReceivesCorrectHeadersAndBody(t *tes })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -78,8 +78,7 @@ func Test_UploadSegments_MultipleSegmentsOutOfOrder_UploadedInAscendingOrder(t * })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -115,8 +114,7 @@ func Test_UploadSegments_DirectoryHasTmpFiles_TmpFilesIgnored(t *testing.T) { })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -146,7 +144,8 @@ func Test_UploadSegment_DeleteEnabled_FileRemovedAfterUpload(t *testing.T) { isDeleteEnabled := true cfg := createTestConfig(walDir, server.URL) cfg.IsDeleteWalAfterUpload = &isDeleteEnabled - streamer := NewStreamer(cfg, logger.GetLogger()) + apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) + streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -173,7 +172,8 @@ func Test_UploadSegment_DeleteDisabled_FileKeptAfterUpload(t *testing.T) { isDeleteDisabled := false cfg := createTestConfig(walDir, server.URL) cfg.IsDeleteWalAfterUpload = &isDeleteDisabled - streamer := NewStreamer(cfg, logger.GetLogger()) + apiClient := api.NewClient(server.URL, cfg.Token, logger.GetLogger()) + streamer := NewStreamer(cfg, apiClient, logger.GetLogger()) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -198,8 +198,7 @@ func Test_UploadSegment_ServerReturns500_FileKeptInQueue(t *testing.T) { })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -223,8 +222,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) { })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -239,8 +237,7 @@ func Test_ProcessQueue_EmptyDirectory_NoUploads(t *testing.T) { func Test_Run_ContextCancelled_StopsImmediately(t *testing.T) { walDir := createTestWalDir(t) - cfg := createTestConfig(walDir, "http://localhost:0") - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, "http://localhost:0") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -269,17 +266,16 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusConflict) - resp := uploadErrorResponse{ - Error: "gap_detected", - ExpectedSegmentName: "000000010000000100000003", - ReceivedSegmentName: segmentName, + resp := map[string]string{ + "error": "gap_detected", + "expectedSegmentName": "000000010000000100000003", + "receivedSegmentName": segmentName, } _ = json.NewEncoder(w).Encode(resp) })) defer server.Close() - cfg := createTestConfig(walDir, server.URL) - streamer := NewStreamer(cfg, logger.GetLogger()) + streamer := newTestStreamer(walDir, server.URL) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -292,6 +288,13 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { assert.NoError(t, err, "segment file should not be deleted on gap detection") } +func newTestStreamer(walDir, serverURL string) *Streamer { + cfg := createTestConfig(walDir, serverURL) + apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger()) + + return NewStreamer(cfg, apiClient, logger.GetLogger()) +} + func createTestWalDir(t *testing.T) string { t.Helper() diff --git a/agent/internal/logger/logger.go b/agent/internal/logger/logger.go index bab697f..7aa6a63 100644 --- a/agent/internal/logger/logger.go +++ b/agent/internal/logger/logger.go @@ -1,45 +1,53 @@ package logger import ( + "fmt" + "io" "log/slog" "os" "sync" "time" ) +const logFileName = "databasus.log" + var ( loggerInstance *slog.Logger once sync.Once ) -func Init(isDebug bool) { - level := slog.LevelInfo - if isDebug { - level = slog.LevelDebug - } - - once.Do(func() { - loggerInstance = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: level, - ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { - if a.Key == slog.TimeKey { - a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05")) - } - if a.Key == slog.LevelKey { - return slog.Attr{} - } - - return a - }, - })) - }) -} - -// GetLogger returns a singleton slog.Logger that logs to the console func GetLogger() *slog.Logger { - if loggerInstance == nil { - Init(false) - } + once.Do(func() { + initialize() + }) return loggerInstance } + +func initialize() { + writer := buildWriter() + + loggerInstance = slog.New(slog.NewTextHandler(writer, &slog.HandlerOptions{ + Level: slog.LevelInfo, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + if a.Key == slog.TimeKey { + a.Value = slog.StringValue(time.Now().Format("2006/01/02 15:04:05")) + } + if a.Key == slog.LevelKey { + return slog.Attr{} + } + + return a + }, + })) +} + +func buildWriter() io.Writer { + f, err := os.OpenFile(logFileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to open %s for logging: %v\n", logFileName, err) + return os.Stdout + } + + return io.MultiWriter(os.Stdout, f) +}