mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
FEATURE (backups): Allow single backup download to avoid exhausting of server throughput
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
backups_core "databasus-backend/internal/features/backups/backups/core"
|
||||
backups_download "databasus-backend/internal/features/backups/backups/download"
|
||||
"databasus-backend/internal/features/databases"
|
||||
users_middleware "databasus-backend/internal/features/users/middleware"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -199,14 +202,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup using a download token
|
||||
// @Description Download the backup file for the specified backup using a download token.
|
||||
// @Description
|
||||
// @Description **Download Concurrency Control:**
|
||||
// @Description - Only one download per user is allowed at a time
|
||||
// @Description - If a download is already in progress, returns 409 Conflict
|
||||
// @Description - Downloads are tracked using cache with 5-second TTL and 3-second heartbeat
|
||||
// @Description - Browser cancellations automatically release the download lock
|
||||
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Param token query string true "Download token"
|
||||
// @Success 200 {file} file
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 409 {object} map[string]string "Download already in progress"
|
||||
// @Failure 500 {object} map[string]string
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
token := ctx.Query("token")
|
||||
@@ -215,7 +226,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get backup ID from URL
|
||||
backupIDParam := ctx.Param("id")
|
||||
backupID, err := uuid.Parse(backupIDParam)
|
||||
if err != nil {
|
||||
@@ -225,11 +235,20 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
|
||||
downloadToken, err := c.backupService.ValidateDownloadToken(token)
|
||||
if err != nil {
|
||||
if err == backups_download.ErrDownloadAlreadyInProgress {
|
||||
ctx.JSON(
|
||||
http.StatusConflict,
|
||||
gin.H{
|
||||
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
|
||||
},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Verify token is for the requested backup
|
||||
if downloadToken.BackupID != backupID {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
|
||||
return
|
||||
@@ -239,18 +258,24 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
downloadToken.BackupID,
|
||||
)
|
||||
if err != nil {
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
cancelHeartbeat()
|
||||
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
|
||||
if err := fileReader.Close(); err != nil {
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
|
||||
|
||||
filename := c.generateBackupFilename(backup, database)
|
||||
|
||||
// Set Content-Length for progress tracking
|
||||
if backup.BackupSizeMb > 0 {
|
||||
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
|
||||
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
|
||||
@@ -268,7 +293,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Write audit log after successful download
|
||||
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
|
||||
}
|
||||
|
||||
@@ -334,3 +358,17 @@ func sanitizeFilename(name string) string {
|
||||
|
||||
return string(result)
|
||||
}
|
||||
|
||||
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
|
||||
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
c.backupService.RefreshDownloadLock(userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -950,6 +950,107 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func Test_ConcurrentDownloadPrevention(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var token1Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token1Response,
|
||||
)
|
||||
|
||||
var token2Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token2Response,
|
||||
)
|
||||
|
||||
downloadInProgress := make(chan bool, 1)
|
||||
downloadComplete := make(chan bool, 1)
|
||||
|
||||
go func() {
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/backups/%s/file?token=%s",
|
||||
backup.ID.String(),
|
||||
token1Response.Token,
|
||||
),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
downloadComplete <- true
|
||||
}()
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
service := GetBackupService()
|
||||
if !service.IsDownloadInProgress(owner.UserID) {
|
||||
t.Log("Warning: First download completed before we could test concurrency")
|
||||
<-downloadComplete
|
||||
return
|
||||
}
|
||||
|
||||
downloadInProgress <- true
|
||||
|
||||
resp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token2Response.Token),
|
||||
"",
|
||||
http.StatusConflict,
|
||||
)
|
||||
|
||||
var errorResponse map[string]string
|
||||
err := json.Unmarshal(resp.Body, &errorResponse)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, errorResponse["error"], "download already in progress")
|
||||
|
||||
<-downloadComplete
|
||||
<-downloadInProgress
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
var token3Response backups_download.GenerateDownloadTokenResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/download-token", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
&token3Response,
|
||||
)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file?token=%s", backup.ID.String(), token3Response.Token),
|
||||
"",
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
t.Log("Database:", database.Name)
|
||||
t.Log(
|
||||
"Successfully prevented concurrent downloads and allowed subsequent downloads after completion",
|
||||
)
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"databasus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var downloadTokenRepository = &DownloadTokenRepository{}
|
||||
|
||||
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
|
||||
|
||||
var downloadTokenService = &DownloadTokenService{
|
||||
downloadTokenRepository,
|
||||
logger.GetLogger(),
|
||||
downloadTracker,
|
||||
}
|
||||
|
||||
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
|
||||
|
||||
@@ -9,8 +9,9 @@ import (
|
||||
)
|
||||
|
||||
type DownloadTokenService struct {
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
repository *DownloadTokenRepository
|
||||
logger *slog.Logger
|
||||
downloadTracker *DownloadTracker
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
|
||||
@@ -50,15 +51,32 @@ func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken,
|
||||
return nil, errors.New("token expired")
|
||||
}
|
||||
|
||||
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dt.Used = true
|
||||
if err := s.repository.Update(dt); err != nil {
|
||||
s.logger.Error("Failed to mark token as used", "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
|
||||
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
|
||||
return dt, nil
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTracker.ReleaseDownloadLock(userID)
|
||||
s.logger.Info("Released download lock", "userId", userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTracker.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *DownloadTokenService) CleanExpiredTokens() error {
|
||||
now := time.Now().UTC()
|
||||
if err := s.repository.DeleteExpired(now); err != nil {
|
||||
|
||||
@@ -0,0 +1,66 @@
|
||||
package backups_download
|
||||
|
||||
import (
|
||||
cache_utils "databasus-backend/internal/util/cache"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/valkey-io/valkey-go"
|
||||
)
|
||||
|
||||
const (
|
||||
downloadLockPrefix = "backup_download_lock:"
|
||||
downloadLockTTL = 5 * time.Second
|
||||
downloadLockValue = "1"
|
||||
downloadHeartbeatDelay = 3 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
|
||||
)
|
||||
|
||||
type DownloadTracker struct {
|
||||
cache *cache_utils.CacheUtil[string]
|
||||
}
|
||||
|
||||
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
|
||||
return &DownloadTracker{
|
||||
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
|
||||
key := userID.String()
|
||||
|
||||
existingLock := t.cache.Get(key)
|
||||
if existingLock != nil {
|
||||
return ErrDownloadAlreadyInProgress
|
||||
}
|
||||
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
value := downloadLockValue
|
||||
t.cache.Set(key, &value)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
key := userID.String()
|
||||
t.cache.Invalidate(key)
|
||||
}
|
||||
|
||||
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
key := userID.String()
|
||||
existingLock := t.cache.Get(key)
|
||||
return existingLock != nil
|
||||
}
|
||||
|
||||
func GetDownloadHeartbeatInterval() time.Duration {
|
||||
return downloadHeartbeatDelay
|
||||
}
|
||||
@@ -522,6 +522,18 @@ func (s *BackupService) WriteAuditLogForDownload(
|
||||
)
|
||||
}
|
||||
|
||||
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.RefreshDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
|
||||
s.downloadTokenService.ReleaseDownloadLock(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
|
||||
return s.downloadTokenService.IsDownloadInProgress(userID)
|
||||
}
|
||||
|
||||
func (s *BackupService) generateBackupFilename(
|
||||
backup *backups_core.Backup,
|
||||
database *databases.Database,
|
||||
|
||||
Reference in New Issue
Block a user