FEATURE (backups): Allow single backup download to avoid exhausting of server throughput

This commit is contained in:
Rostislav Dugin
2026-01-14 13:05:48 +03:00
parent f319a497b3
commit b60a0cc170
6 changed files with 250 additions and 11 deletions

View File

@@ -1,12 +1,15 @@
package backups
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -199,14 +202,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Description Download the backup file for the specified backup using a download token.
// @Description
// @Description **Download Concurrency Control:**
// @Description - Only one download per user is allowed at a time
// @Description - If a download is already in progress, returns 409 Conflict
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
// @Description - Browser cancellations automatically release the download lock
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 409 {object} map[string]string "Download already in progress"
// @Failure 500 {object} map[string]string
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
@@ -215,7 +226,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
@@ -225,11 +235,20 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
@@ -239,18 +258,24 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken.BackupID,
)
if err != nil {
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer func() {
cancelHeartbeat()
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
if err := fileReader.Close(); err != nil {
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
@@ -268,7 +293,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
@@ -334,3 +358,17 @@ func sanitizeFilename(name string) string {
return string(result)
}
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
c.backupService.RefreshDownloadLock(userID)
}
}
}

View File

@@ -950,6 +950,107 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.True(t, foundCancelLog, "Cancel audit log should be created")
}
func Test_ConcurrentDownloadPrevention(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
var token1Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token1Response,
)
var token2Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token2Response,
)
downloadInProgress := make(chan bool, 1)
downloadComplete := make(chan bool, 1)
go func() {
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf(
"/api/v1/backups/%s/file?token=%s",
backup.ID.String(),
token1Response.Token,
),
"",
http.StatusOK,
)
downloadComplete <- true
}()
time.Sleep(50 * time.Millisecond)
service := GetBackupService()
if !service.IsDownloadInProgress(owner.UserID) {
t.Log("Warning: First download completed before we could test concurrency")
<-downloadComplete
return
}
downloadInProgress <- true
resp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
"",
http.StatusConflict,
)
var errorResponse map[string]string
err := json.Unmarshal(resp.Body, &errorResponse)
assert.NoError(t, err)
assert.Contains(t, errorResponse["error"], "download already in progress")
<-downloadComplete
<-downloadInProgress
time.Sleep(100 * time.Millisecond)
var token3Response backups_download.GenerateDownloadTokenResponse
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
&token3Response,
)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
"",
http.StatusOK,
)
t.Log("Database:", database.Name)
t.Log(
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
)
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}

View File

@@ -1,14 +1,18 @@
package backups_download
import (
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
var downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
downloadTracker,
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{

View File

@@ -9,8 +9,9 @@ import (
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
repository *DownloadTokenRepository
logger *slog.Logger
downloadTracker *DownloadTracker
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
@@ -50,15 +51,32 @@ func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken,
return nil, errors.New("token expired")
}
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
return nil, err
}
dt.Used = true
if err := s.repository.Update(dt); err != nil {
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
return dt, nil
}
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTracker.RefreshDownloadLock(userID)
}
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTracker.ReleaseDownloadLock(userID)
s.logger.Info("Released download lock", "userId", userID)
}
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTracker.IsDownloadInProgress(userID)
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
now := time.Now().UTC()
if err := s.repository.DeleteExpired(now); err != nil {

View File

@@ -0,0 +1,66 @@
package backups_download
import (
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
downloadLockPrefix = "backup_download_lock:"
downloadLockTTL = 5 * time.Second
downloadLockValue = "1"
downloadHeartbeatDelay = 3 * time.Second
)
var (
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
)
type DownloadTracker struct {
cache *cache_utils.CacheUtil[string]
}
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
return &DownloadTracker{
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
}
}
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
key := userID.String()
existingLock := t.cache.Get(key)
if existingLock != nil {
return ErrDownloadAlreadyInProgress
}
value := downloadLockValue
t.cache.Set(key, &value)
return nil
}
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
key := userID.String()
value := downloadLockValue
t.cache.Set(key, &value)
}
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
key := userID.String()
t.cache.Invalidate(key)
}
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
key := userID.String()
existingLock := t.cache.Get(key)
return existingLock != nil
}
func GetDownloadHeartbeatInterval() time.Duration {
return downloadHeartbeatDelay
}

View File

@@ -522,6 +522,18 @@ func (s *BackupService) WriteAuditLogForDownload(
)
}
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTokenService.RefreshDownloadLock(userID)
}
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTokenService.ReleaseDownloadLock(userID)
}
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTokenService.IsDownloadInProgress(userID)
}
func (s *BackupService) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,