diff --git a/backend/internal/features/backups/backups/controller.go b/backend/internal/features/backups/backups/controller.go index d569729..0b5d034 100644 --- a/backend/internal/features/backups/backups/controller.go +++ b/backend/internal/features/backups/backups/controller.go @@ -1,12 +1,15 @@ package backups import ( + "context" backups_core "databasus-backend/internal/features/backups/backups/core" + backups_download "databasus-backend/internal/features/backups/backups/download" "databasus-backend/internal/features/databases" users_middleware "databasus-backend/internal/features/users/middleware" "fmt" "io" "net/http" + "time" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -199,14 +202,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) { // GetFile // @Summary Download a backup file -// @Description Download the backup file for the specified backup using a download token +// @Description Download the backup file for the specified backup using a download token. +// @Description +// @Description **Download Concurrency Control:** +// @Description - Only one download per user is allowed at a time +// @Description - If a download is already in progress, returns 409 Conflict +// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat +// @Description - Browser cancellations automatically release the download lock +// @Description - Server crashes are handled via automatic cache expiry (5 seconds) // @Tags backups // @Param id path string true "Backup ID" // @Param token query string true "Download token" // @Success 200 {file} file -// @Failure 400 -// @Failure 401 -// @Failure 500 +// @Failure 400 {object} map[string]string +// @Failure 401 {object} map[string]string +// @Failure 409 {object} map[string]string "Download already in progress" +// @Failure 500 {object} map[string]string // @Router /backups/{id}/file [get] func (c *BackupController) GetFile(ctx *gin.Context) { token := ctx.Query("token") @@ -215,7 +226,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) { return } - // Get backup ID from URL backupIDParam := ctx.Param("id") backupID, err := uuid.Parse(backupIDParam) if err != nil { @@ -225,11 +235,20 @@ func (c *BackupController) GetFile(ctx *gin.Context) { downloadToken, err := c.backupService.ValidateDownloadToken(token) if err != nil { + if err == backups_download.ErrDownloadAlreadyInProgress { + ctx.JSON( + http.StatusConflict, + gin.H{ + "error": "download already in progress for this user. Please wait until previous download completed or cancel it", + }, + ) + return + } + ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"}) return } - // Verify token is for the requested backup if downloadToken.BackupID != backupID { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"}) return @@ -239,18 +258,24 @@ func (c *BackupController) GetFile(ctx *gin.Context) { downloadToken.BackupID, ) if err != nil { + c.backupService.ReleaseDownloadLock(downloadToken.UserID) ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } + + heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background()) defer func() { + cancelHeartbeat() + c.backupService.ReleaseDownloadLock(downloadToken.UserID) if err := fileReader.Close(); err != nil { fmt.Printf("Error closing file reader: %v\n", err) } }() + go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID) + filename := c.generateBackupFilename(backup, database) - // Set Content-Length for progress tracking if backup.BackupSizeMb > 0 { sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024) ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes)) @@ -268,7 +293,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) { return } - // Write audit log after successful download c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database) } @@ -334,3 +358,17 @@ func sanitizeFilename(name string) string { return string(result) } + +func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) { + ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval()) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.backupService.RefreshDownloadLock(userID) + } + } +} diff --git a/backend/internal/features/backups/backups/controller_test.go b/backend/internal/features/backups/backups/controller_test.go index 17425d3..ed7d437 100644 --- a/backend/internal/features/backups/backups/controller_test.go +++ b/backend/internal/features/backups/backups/controller_test.go @@ -950,6 +950,107 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) { assert.True(t, foundCancelLog, "Cancel audit log should be created") } +func Test_ConcurrentDownloadPrevention(t *testing.T) { + router := createTestRouter() + owner := users_testing.CreateTestUser(users_enums.UserRoleMember) + workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router) + + database, backup := createTestDatabaseWithBackups(workspace, owner, router) + + var token1Response backups_download.GenerateDownloadTokenResponse + test_utils.MakePostRequestAndUnmarshal( + t, + router, + fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()), + "Bearer "+owner.Token, + nil, + http.StatusOK, + &token1Response, + ) + + var token2Response backups_download.GenerateDownloadTokenResponse + test_utils.MakePostRequestAndUnmarshal( + t, + router, + fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()), + "Bearer "+owner.Token, + nil, + http.StatusOK, + &token2Response, + ) + + downloadInProgress := make(chan bool, 1) + downloadComplete := make(chan bool, 1) + + go func() { + test_utils.MakeGetRequest( + t, + router, + fmt.Sprintf( + "/api/v1/backups/%s/file?token=%s", + backup.ID.String(), + token1Response.Token, + ), + "", + http.StatusOK, + ) + downloadComplete <- true + }() + + time.Sleep(50 * time.Millisecond) + + service := GetBackupService() + if !service.IsDownloadInProgress(owner.UserID) { + t.Log("Warning: First download completed before we could test concurrency") + <-downloadComplete + return + } + + downloadInProgress <- true + + resp := test_utils.MakeGetRequest( + t, + router, + fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token), + "", + http.StatusConflict, + ) + + var errorResponse map[string]string + err := json.Unmarshal(resp.Body, &errorResponse) + assert.NoError(t, err) + assert.Contains(t, errorResponse["error"], "download already in progress") + + <-downloadComplete + <-downloadInProgress + + time.Sleep(100 * time.Millisecond) + + var token3Response backups_download.GenerateDownloadTokenResponse + test_utils.MakePostRequestAndUnmarshal( + t, + router, + fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()), + "Bearer "+owner.Token, + nil, + http.StatusOK, + &token3Response, + ) + + test_utils.MakeGetRequest( + t, + router, + fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token), + "", + http.StatusOK, + ) + + t.Log("Database:", database.Name) + t.Log( + "Successfully prevented concurrent downloads and allowed subsequent downloads after completion", + ) +} + func createTestRouter() *gin.Engine { return CreateTestRouter() } diff --git a/backend/internal/features/backups/backups/download/di.go b/backend/internal/features/backups/backups/download/di.go index befc0cc..5a2c0e6 100644 --- a/backend/internal/features/backups/backups/download/di.go +++ b/backend/internal/features/backups/backups/download/di.go @@ -1,14 +1,18 @@ package backups_download import ( + cache_utils "databasus-backend/internal/util/cache" "databasus-backend/internal/util/logger" ) var downloadTokenRepository = &DownloadTokenRepository{} +var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient()) + var downloadTokenService = &DownloadTokenService{ downloadTokenRepository, logger.GetLogger(), + downloadTracker, } var downloadTokenBackgroundService = &DownloadTokenBackgroundService{ diff --git a/backend/internal/features/backups/backups/download/service.go b/backend/internal/features/backups/backups/download/service.go index b25a6b6..0f4d313 100644 --- a/backend/internal/features/backups/backups/download/service.go +++ b/backend/internal/features/backups/backups/download/service.go @@ -9,8 +9,9 @@ import ( ) type DownloadTokenService struct { - repository *DownloadTokenRepository - logger *slog.Logger + repository *DownloadTokenRepository + logger *slog.Logger + downloadTracker *DownloadTracker } func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) { @@ -50,15 +51,32 @@ func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, return nil, errors.New("token expired") } + if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil { + return nil, err + } + dt.Used = true if err := s.repository.Update(dt); err != nil { s.logger.Error("Failed to mark token as used", "error", err) } - s.logger.Info("Token validated and consumed", "backupId", dt.BackupID) + s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID) return dt, nil } +func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) { + s.downloadTracker.RefreshDownloadLock(userID) +} + +func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) { + s.downloadTracker.ReleaseDownloadLock(userID) + s.logger.Info("Released download lock", "userId", userID) +} + +func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool { + return s.downloadTracker.IsDownloadInProgress(userID) +} + func (s *DownloadTokenService) CleanExpiredTokens() error { now := time.Now().UTC() if err := s.repository.DeleteExpired(now); err != nil { diff --git a/backend/internal/features/backups/backups/download/tracking.go b/backend/internal/features/backups/backups/download/tracking.go new file mode 100644 index 0000000..64fc6eb --- /dev/null +++ b/backend/internal/features/backups/backups/download/tracking.go @@ -0,0 +1,66 @@ +package backups_download + +import ( + cache_utils "databasus-backend/internal/util/cache" + "errors" + "time" + + "github.com/google/uuid" + "github.com/valkey-io/valkey-go" +) + +const ( + downloadLockPrefix = "backup_download_lock:" + downloadLockTTL = 5 * time.Second + downloadLockValue = "1" + downloadHeartbeatDelay = 3 * time.Second +) + +var ( + ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user") +) + +type DownloadTracker struct { + cache *cache_utils.CacheUtil[string] +} + +func NewDownloadTracker(client valkey.Client) *DownloadTracker { + return &DownloadTracker{ + cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix), + } +} + +func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error { + key := userID.String() + + existingLock := t.cache.Get(key) + if existingLock != nil { + return ErrDownloadAlreadyInProgress + } + + value := downloadLockValue + t.cache.Set(key, &value) + + return nil +} + +func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) { + key := userID.String() + value := downloadLockValue + t.cache.Set(key, &value) +} + +func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) { + key := userID.String() + t.cache.Invalidate(key) +} + +func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool { + key := userID.String() + existingLock := t.cache.Get(key) + return existingLock != nil +} + +func GetDownloadHeartbeatInterval() time.Duration { + return downloadHeartbeatDelay +} diff --git a/backend/internal/features/backups/backups/service.go b/backend/internal/features/backups/backups/service.go index ccac3a9..6611399 100644 --- a/backend/internal/features/backups/backups/service.go +++ b/backend/internal/features/backups/backups/service.go @@ -522,6 +522,18 @@ func (s *BackupService) WriteAuditLogForDownload( ) } +func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) { + s.downloadTokenService.RefreshDownloadLock(userID) +} + +func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) { + s.downloadTokenService.ReleaseDownloadLock(userID) +} + +func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool { + return s.downloadTokenService.IsDownloadInProgress(userID) +} + func (s *BackupService) generateBackupFilename( backup *backups_core.Backup, database *databases.Database,