diff --git a/agent/internal/features/api/idle_timeout_reader.go b/agent/internal/features/api/idle_timeout_reader.go new file mode 100644 index 0000000..6a59ffe --- /dev/null +++ b/agent/internal/features/api/idle_timeout_reader.go @@ -0,0 +1,60 @@ +package api + +import ( + "context" + "fmt" + "io" + "time" +) + +// IdleTimeoutReader wraps an io.Reader and cancels the associated context +// if no bytes are successfully read within the specified timeout duration. +// This detects stalled uploads where the network or source stops transmitting data. +// +// When the idle timeout fires, the reader is also closed (if it implements io.Closer) +// to unblock any goroutine blocked on the underlying Read. +type IdleTimeoutReader struct { + reader io.Reader + timeout time.Duration + cancel context.CancelCauseFunc + timer *time.Timer +} + +// NewIdleTimeoutReader creates a reader that cancels the context via cancel +// if Read does not return any bytes for the given timeout duration. +func NewIdleTimeoutReader(reader io.Reader, timeout time.Duration, cancel context.CancelCauseFunc) *IdleTimeoutReader { + r := &IdleTimeoutReader{ + reader: reader, + timeout: timeout, + cancel: cancel, + } + + r.timer = time.AfterFunc(timeout, func() { + cancel(fmt.Errorf("upload idle timeout: no bytes transmitted for %v", timeout)) + + if closer, ok := reader.(io.Closer); ok { + _ = closer.Close() + } + }) + + return r +} + +func (r *IdleTimeoutReader) Read(p []byte) (int, error) { + n, err := r.reader.Read(p) + + if n > 0 { + r.timer.Reset(r.timeout) + } + + if err != nil { + r.Stop() + } + + return n, err +} + +// Stop cancels the idle timer. Must be called when the reader is no longer needed. +func (r *IdleTimeoutReader) Stop() { + r.timer.Stop() +} diff --git a/agent/internal/features/api/idle_timeout_reader_test.go b/agent/internal/features/api/idle_timeout_reader_test.go new file mode 100644 index 0000000..17d52ff --- /dev/null +++ b/agent/internal/features/api/idle_timeout_reader_test.go @@ -0,0 +1,112 @@ +package api + +import ( + "context" + "fmt" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ReadThroughIdleTimeoutReader_WhenBytesFlowContinuously_DoesNotCancelContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 200*time.Millisecond, cancel) + defer idleReader.Stop() + + go func() { + for range 5 { + _, _ = pw.Write([]byte("data")) + time.Sleep(50 * time.Millisecond) + } + + _ = pw.Close() + }() + + data, err := io.ReadAll(idleReader) + + require.NoError(t, err) + assert.Equal(t, "datadatadatadatadata", string(data)) + assert.NoError(t, ctx.Err(), "context should not be cancelled when bytes flow continuously") +} + +func Test_ReadThroughIdleTimeoutReader_WhenNoBytesTransmitted_CancelsContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, _ := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + defer idleReader.Stop() + + time.Sleep(200 * time.Millisecond) + + assert.Error(t, ctx.Err(), "context should be cancelled when no bytes are transmitted") + assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout") +} + +func Test_ReadThroughIdleTimeoutReader_WhenBytesStopMidStream_CancelsContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + defer idleReader.Stop() + + go func() { + _, _ = pw.Write([]byte("initial")) + // Stop writing — simulate stalled source + }() + + buf := make([]byte, 1024) + n, _ := idleReader.Read(buf) + assert.Equal(t, "initial", string(buf[:n])) + + time.Sleep(200 * time.Millisecond) + + assert.Error(t, ctx.Err(), "context should be cancelled when bytes stop mid-stream") + assert.Contains(t, context.Cause(ctx).Error(), "upload idle timeout") +} + +func Test_StopIdleTimeoutReader_WhenCalledBeforeTimeout_DoesNotCancelContext(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, _ := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 100*time.Millisecond, cancel) + idleReader.Stop() + + time.Sleep(200 * time.Millisecond) + + assert.NoError(t, ctx.Err(), "context should not be cancelled when reader is stopped before timeout") +} + +func Test_ReadThroughIdleTimeoutReader_WhenReaderReturnsError_PropagatesError(t *testing.T) { + ctx, cancel := context.WithCancelCause(t.Context()) + defer cancel(nil) + + pr, pw := io.Pipe() + + idleReader := NewIdleTimeoutReader(pr, 5*time.Second, cancel) + defer idleReader.Stop() + + expectedErr := fmt.Errorf("test read error") + _ = pw.CloseWithError(expectedErr) + + buf := make([]byte, 1024) + _, err := idleReader.Read(buf) + + assert.ErrorIs(t, err, expectedErr) + + // Timer should be stopped after error — context should not be cancelled + time.Sleep(100 * time.Millisecond) + assert.NoError(t, ctx.Err(), "context should not be cancelled after reader error stops the timer") +} diff --git a/agent/internal/features/full_backup/backuper.go b/agent/internal/features/full_backup/backuper.go index e336311..1e12221 100644 --- a/agent/internal/features/full_backup/backuper.go +++ b/agent/internal/features/full_backup/backuper.go @@ -21,9 +21,11 @@ import ( const ( checkInterval = 30 * time.Second retryDelay = 1 * time.Minute - uploadTimeout = 30 * time.Minute + uploadTimeout = 23 * time.Hour ) +var uploadIdleTimeout = 5 * time.Minute + var retryDelayOverride *time.Duration type CmdBuilder func(ctx context.Context) *exec.Cmd @@ -176,16 +178,32 @@ func (backuper *FullBackuper) executeAndUploadBasebackup(ctx context.Context) er // Phase 1: Stream compressed data via io.Pipe directly to the API. pipeReader, pipeWriter := io.Pipe() + defer func() { _ = pipeReader.Close() }() + go backuper.compressAndStream(pipeWriter, stdoutPipe) - uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) - defer cancel() + uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout) + defer timeoutCancel() - uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(uploadCtx, pipeReader) + idleCtx, idleCancel := context.WithCancelCause(uploadCtx) + defer idleCancel(nil) + + idleReader := api.NewIdleTimeoutReader(pipeReader, uploadIdleTimeout, idleCancel) + defer idleReader.Stop() + + uploadResp, uploadErr := backuper.apiClient.UploadBasebackup(idleCtx, idleReader) + + if uploadErr != nil && cmd.Process != nil { + _ = cmd.Process.Kill() + } cmdErr := cmd.Wait() if uploadErr != nil { + if cause := context.Cause(idleCtx); cause != nil { + uploadErr = cause + } + stderrStr := stderrBuf.String() if stderrStr != "" { return fmt.Errorf("upload basebackup: %w (pg_basebackup stderr: %s)", uploadErr, stderrStr) diff --git a/agent/internal/features/full_backup/backuper_test.go b/agent/internal/features/full_backup/backuper_test.go index 7d423a4..610b3fb 100644 --- a/agent/internal/features/full_backup/backuper_test.go +++ b/agent/internal/features/full_backup/backuper_test.go @@ -562,6 +562,68 @@ func Test_RunFullBackup_WhenUploadSucceeds_BodyIsZstdCompressed(t *testing.T) { assert.Equal(t, originalContent, string(decompressed)) } +func Test_RunFullBackup_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) { + server := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case testFullStartPath: + // Server reads body normally — it will block until connection is closed + _, _ = io.ReadAll(r.Body) + writeJSON(w, map[string]string{"backupId": testBackupID}) + default: + w.WriteHeader(http.StatusNotFound) + } + }) + + fb := newTestFullBackuper(server.URL) + fb.cmdBuilder = stallingCmdBuilder(t) + + origIdleTimeout := uploadIdleTimeout + uploadIdleTimeout = 200 * time.Millisecond + defer func() { uploadIdleTimeout = origIdleTimeout }() + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + err := fb.executeAndUploadBasebackup(ctx) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "idle timeout", "error should mention idle timeout") +} + +func stallingCmdBuilder(t *testing.T) CmdBuilder { + t.Helper() + + return func(ctx context.Context) *exec.Cmd { + cmd := exec.CommandContext(ctx, os.Args[0], + "-test.run=TestHelperProcessStalling", + "--", + ) + + cmd.Env = append(os.Environ(), "GO_TEST_HELPER_PROCESS_STALLING=1") + + return cmd + } +} + +func TestHelperProcessStalling(t *testing.T) { + if os.Getenv("GO_TEST_HELPER_PROCESS_STALLING") != "1" { + return + } + + // Write enough data to flush through the zstd encoder's internal buffer (~128KB blocks). + // Without enough data, zstd buffers everything and the pipe never receives bytes. + data := make([]byte, 256*1024) + for i := range data { + data[i] = byte(i) + } + _, _ = os.Stdout.Write(data) + + // Stall with stdout open — the compress goroutine blocks on its next read. + // The parent process will kill us when the context is cancelled. + time.Sleep(time.Hour) + os.Exit(0) +} + func newTestServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { t.Helper() diff --git a/agent/internal/features/wal/streamer.go b/agent/internal/features/wal/streamer.go index 7ae5337..d5b5702 100644 --- a/agent/internal/features/wal/streamer.go +++ b/agent/internal/features/wal/streamer.go @@ -18,6 +18,8 @@ import ( "databasus-agent/internal/features/api" ) +var uploadIdleTimeout = 5 * time.Minute + const ( pollInterval = 10 * time.Second uploadTimeout = 5 * time.Minute @@ -122,16 +124,27 @@ func (s *Streamer) uploadSegment(ctx context.Context, segmentName string) error filePath := filepath.Join(s.cfg.PgWalDir, segmentName) pr, pw := io.Pipe() + defer func() { _ = pr.Close() }() go s.compressAndStream(pw, filePath) - uploadCtx, cancel := context.WithTimeout(ctx, uploadTimeout) - defer cancel() + uploadCtx, timeoutCancel := context.WithTimeout(ctx, uploadTimeout) + defer timeoutCancel() + + idleCtx, idleCancel := context.WithCancelCause(uploadCtx) + defer idleCancel(nil) + + idleReader := api.NewIdleTimeoutReader(pr, uploadIdleTimeout, idleCancel) + defer idleReader.Stop() s.log.Info("Uploading WAL segment", "segment", segmentName) - result, err := s.apiClient.UploadWalSegment(uploadCtx, segmentName, pr) + result, err := s.apiClient.UploadWalSegment(idleCtx, segmentName, idleReader) if err != nil { + if cause := context.Cause(idleCtx); cause != nil { + return fmt.Errorf("upload WAL segment: %w", cause) + } + return err } diff --git a/agent/internal/features/wal/streamer_test.go b/agent/internal/features/wal/streamer_test.go index fc161a8..5647cbf 100644 --- a/agent/internal/features/wal/streamer_test.go +++ b/agent/internal/features/wal/streamer_test.go @@ -2,6 +2,7 @@ package wal import ( "context" + "crypto/rand" "encoding/json" "io" "net/http" @@ -9,6 +10,7 @@ import ( "os" "path/filepath" "sync" + "sync/atomic" "testing" "time" @@ -287,6 +289,49 @@ func Test_UploadSegment_ServerReturns409_FileNotDeleted(t *testing.T) { assert.NoError(t, err, "segment file should not be deleted on gap detection") } +func Test_UploadSegment_WhenUploadStalls_FailsWithIdleTimeout(t *testing.T) { + walDir := createTestWalDir(t) + + // Use incompressible random data to ensure TCP buffers fill up + segmentContent := make([]byte, 1024*1024) + _, err := rand.Read(segmentContent) + require.NoError(t, err) + + writeTestSegment(t, walDir, "000000010000000100000001", segmentContent) + + var requestReceived atomic.Bool + handlerDone := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestReceived.Store(true) + + // Read one byte then stall — simulates a network stall + buf := make([]byte, 1) + _, _ = r.Body.Read(buf) + <-handlerDone + })) + defer server.Close() + defer close(handlerDone) + + origIdleTimeout := uploadIdleTimeout + uploadIdleTimeout = 200 * time.Millisecond + defer func() { uploadIdleTimeout = origIdleTimeout }() + + streamer := newTestStreamer(walDir, server.URL) + + ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) + defer cancel() + + uploadErr := streamer.uploadSegment(ctx, "000000010000000100000001") + + assert.Error(t, uploadErr, "upload should fail when stalled") + assert.True(t, requestReceived.Load(), "server should have received the request") + assert.Contains(t, uploadErr.Error(), "idle timeout", "error should mention idle timeout") + + _, statErr := os.Stat(filepath.Join(walDir, "000000010000000100000001")) + assert.NoError(t, statErr, "segment file should remain in queue after idle timeout") +} + func newTestStreamer(walDir, serverURL string) *Streamer { cfg := createTestConfig(walDir, serverURL) apiClient := api.NewClient(serverURL, cfg.Token, logger.GetLogger())