FIX (restores): Restore via stream instead of downloading backup to local storage

This commit is contained in:
Rostislav Dugin
2026-01-02 16:06:46 +03:00
parent 58ae86ff7a
commit 5a89558cf6
14 changed files with 583 additions and 493 deletions

View File

@@ -335,11 +335,6 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlD
"--verbose",
}
// Add parallel jobs based on CPU count
if pg.CpuCount > 1 {
args = append(args, "-j", strconv.Itoa(pg.CpuCount))
}
for _, schema := range pg.IncludeSchemas {
args = append(args, "-n", schema)
}

View File

@@ -202,8 +202,10 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
defer container.DB.Close()
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS schema_a;
CREATE SCHEMA IF NOT EXISTS schema_b;
DROP SCHEMA IF EXISTS schema_a CASCADE;
DROP SCHEMA IF EXISTS schema_b CASCADE;
CREATE SCHEMA schema_a;
CREATE SCHEMA schema_b;
CREATE TABLE schema_a.table_a (id INT, data TEXT);
CREATE TABLE schema_b.table_b (id INT, data TEXT);
INSERT INTO schema_a.table_a VALUES (1, 'data_a');

View File

@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/features/backups/backups"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
@@ -35,18 +36,6 @@ import (
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
return router
}
func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -250,42 +239,122 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
assert.True(t, found, "Audit log for restore not found")
}
func Test_RestoreBackup_InsufficientDiskSpace_ReturnsError(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
// Update backup size to 10 TB via repository
repo := &backups.BackupRepository{}
backup.BackupSizeMb = 10485760.0 // 10 TB in MB
err := repo.Save(backup)
assert.NoError(t, err)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
tests := []struct {
name string
dbType databases.DatabaseType
cpuCount int
expectDiskValidated bool
}{
{
name: "PostgreSQL_CPU4_SpaceValidated",
dbType: databases.DatabaseTypePostgres,
cpuCount: 4,
expectDiskValidated: true,
},
{
name: "PostgreSQL_CPU1_SpaceNotValidated",
dbType: databases.DatabaseTypePostgres,
cpuCount: 1,
expectDiskValidated: false,
},
{
name: "MySQL_SpaceNotValidated",
dbType: databases.DatabaseTypeMysql,
cpuCount: 3,
expectDiskValidated: false,
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
bodyStr := string(testResp.Body)
assert.Contains(t, bodyStr, "is required")
assert.Contains(t, bodyStr, "is available")
assert.Contains(t, bodyStr, "disk space")
var backup *backups.Backup
var request RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
request = RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
CpuCount: tc.cpuCount,
},
}
} else {
mysqlDB := createTestMySQLDatabase("Test MySQL DB", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
request = RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Port: 3306,
Username: "root",
Password: "password",
},
}
}
// Set huge backup size (10 TB) that would fail disk validation if checked
repo := &backups.BackupRepository{}
backup.BackupSizeMb = 10485760.0
err := repo.Save(backup)
assert.NoError(t, err)
expectedStatus := http.StatusOK
if tc.expectDiskValidated {
expectedStatus = http.StatusBadRequest
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
expectedStatus,
)
bodyStr := string(testResp.Body)
if tc.expectDiskValidated {
assert.Contains(t, bodyStr, "is required")
assert.Contains(t, bodyStr, "is available")
assert.Contains(t, bodyStr, "disk space")
} else {
assert.Contains(t, bodyStr, "restore started successfully")
}
})
}
}
// createTestRouter assembles a test router wired with every controller the
// restore integration tests exercise (workspaces, memberships, databases,
// backup config, backups and restores).
func createTestRouter() *gin.Engine {
	return workspaces_testing.CreateTestRouter(
		workspaces_controllers.GetWorkspaceController(),
		workspaces_controllers.GetMembershipController(),
		databases.GetDatabaseController(),
		backups_config.GetBackupConfigController(),
		backups.GetBackupController(),
		GetRestoreController(),
	)
}
func createTestDatabaseWithBackupForRestore(
@@ -359,6 +428,53 @@ func createTestDatabase(
return &database
}
// createTestMySQLDatabase registers a MySQL database in the given workspace
// through the public create endpoint and returns the decoded record.
// It panics (rather than failing a *testing.T) on any error, matching the
// other test fixtures in this file.
func createTestMySQLDatabase(
	name string,
	workspaceID uuid.UUID,
	token string,
	router *gin.Engine,
) *databases.Database {
	// Payload points at the local MySQL test server with fixed credentials.
	dbName := "test_db"
	payload := databases.Database{
		WorkspaceID: &workspaceID,
		Name:        name,
		Type:        databases.DatabaseTypeMysql,
		Mysql: &mysql.MysqlDatabase{
			Version:  tools.MysqlVersion80,
			Host:     "localhost",
			Port:     3306,
			Username: "root",
			Password: "password",
			Database: &dbName,
		},
	}

	resp := workspaces_testing.MakeAPIRequest(
		router,
		"POST",
		"/api/v1/databases/create",
		"Bearer "+token,
		payload,
	)
	if resp.Code != http.StatusCreated {
		panic(
			fmt.Sprintf(
				"Failed to create MySQL database. Status: %d, Body: %s",
				resp.Code,
				resp.Body.String(),
			),
		)
	}

	created := databases.Database{}
	if err := json.Unmarshal(resp.Body.Bytes(), &created); err != nil {
		panic(err)
	}
	return &created
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,

View File

@@ -129,7 +129,7 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
// Validate disk space before starting restore
if err := s.validateDiskSpace(backup); err != nil {
if err := s.validateDiskSpace(backup, requestDTO); err != nil {
return err
}
@@ -369,7 +369,24 @@ func (s *RestoreService) validateVersionCompatibility(
return nil
}
func (s *RestoreService) validateDiskSpace(backup *backups.Backup) error {
func (s *RestoreService) validateDiskSpace(
backup *backups.Backup,
requestDTO RestoreBackupRequest,
) error {
// Only validate disk space for PostgreSQL when file-based restore is needed:
// - CPU > 1 (parallel jobs require file)
// - IsExcludeExtensions (TOC filtering requires file)
// Other databases and PostgreSQL with CPU=1 without extension exclusion stream directly
if requestDTO.PostgresqlDatabase == nil {
return nil
}
needsFileBased := requestDTO.PostgresqlDatabase.CpuCount > 1 ||
requestDTO.PostgresqlDatabase.IsExcludeExtensions
if !needsFileBased {
return nil
}
diskUsage, err := s.diskService.GetDiskUsage()
if err != nil {
return fmt.Errorf("failed to check disk space: %w", err)

View File

@@ -27,7 +27,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -134,11 +133,16 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
}
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMariadbRestore(
ctx,
@@ -146,7 +150,7 @@ func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
mariadbBin,
args,
myCnfFile,
tempBackupFile,
rawReader,
backup,
)
}
@@ -157,7 +161,7 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
mariadbBin string,
args []string,
myCnfFile string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -165,16 +169,10 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
cmd := exec.CommandContext(ctx, mariadbBin, fullArgs...)
uc.logger.Info("Executing MariaDB restore command", "command", cmd.String())
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -225,69 +223,6 @@ func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
return nil
}
// downloadBackupToTempFile streams the backup object from storage into a
// newly created temporary file and returns its path together with a cleanup
// function that removes the temporary directory.
//
// The caller must invoke the returned cleanup function once the file is no
// longer needed. On any error the temporary directory is removed before
// returning and the cleanup function is nil.
func (uc *RestoreMariadbBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// One directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}
	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}

	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		_ = tempFile.Close()
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	// Close explicitly and propagate the error: a failed close can mean the
	// file is incomplete, and a deferred close that merely logs would let
	// this function report success with a truncated backup on disk.
	if err := tempFile.Close(); err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to close temporary backup file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -358,57 +293,6 @@ port=%d
return myCnfFile, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks while honouring
// both context cancellation and the global application shutdown flag.
// It returns the number of bytes successfully written to dst.
func (uc *RestoreMariadbBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	chunk := make([]byte, 16*1024*1024)
	var copied int64
	for {
		// Bail out between chunks if the restore was cancelled or the
		// application is shutting down.
		select {
		case <-ctx.Done():
			return copied, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return copied, fmt.Errorf("copy cancelled due to shutdown")
		}

		readCount, readErr := src.Read(chunk)
		if readCount > 0 {
			writeCount, writeErr := dst.Write(chunk[:readCount])
			// Guard against misbehaving writers reporting impossible counts.
			if writeCount < 0 || readCount < writeCount {
				writeCount = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return copied, writeErr
			}
			if readCount != writeCount {
				return copied, io.ErrShortWrite
			}
			copied += int64(writeCount)
		}
		if readErr == io.EOF {
			return copied, nil
		}
		if readErr != nil {
			return copied, readErr
		}
	}
}
func (uc *RestoreMariadbBackupUsecase) handleMariadbRestoreError(
database *databases.Database,
waitErr error,

View File

@@ -13,8 +13,6 @@ import (
"strings"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/encryption"
@@ -25,7 +23,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -149,20 +146,26 @@ func (uc *RestoreMongodbBackupUsecase) restoreFromStorage(
}
}()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMongoRestore(ctx, mongorestoreBin, args, tempBackupFile, backup)
return uc.executeMongoRestore(ctx, mongorestoreBin, args, rawReader, backup)
}
func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
ctx context.Context,
mongorestoreBin string,
args []string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
cmd := exec.CommandContext(ctx, mongorestoreBin, args...)
@@ -183,16 +186,10 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
safeArgs,
)
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -232,69 +229,6 @@ func (uc *RestoreMongodbBackupUsecase) executeMongoRestore(
return nil
}
// downloadBackupToTempFile streams the backup archive from storage into a
// newly created temporary file and returns its path together with a cleanup
// function that removes the temporary directory.
//
// The caller must invoke the returned cleanup function once the file is no
// longer needed. On any error the temporary directory is removed before
// returning and the cleanup function is nil.
func (uc *RestoreMongodbBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// One directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}
	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.archive.gz")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}

	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		_ = tempFile.Close()
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	// Close explicitly and propagate the error: a failed close can mean the
	// file is incomplete, and a deferred close that merely logs would let
	// this function report success with a truncated backup on disk.
	if err := tempFile.Close(); err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to close temporary backup file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMongodbBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -332,57 +266,6 @@ func (uc *RestoreMongodbBackupUsecase) setupDecryption(
return decryptReader, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks while honouring
// both context cancellation and the global application shutdown flag.
// It returns the number of bytes successfully written to dst.
func (uc *RestoreMongodbBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	chunk := make([]byte, 16*1024*1024)
	var copied int64
	for {
		// Bail out between chunks if the restore was cancelled or the
		// application is shutting down.
		select {
		case <-ctx.Done():
			return copied, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return copied, fmt.Errorf("copy cancelled due to shutdown")
		}

		readCount, readErr := src.Read(chunk)
		if readCount > 0 {
			writeCount, writeErr := dst.Write(chunk[:readCount])
			// Guard against misbehaving writers reporting impossible counts.
			if writeCount < 0 || readCount < writeCount {
				writeCount = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return copied, writeErr
			}
			if readCount != writeCount {
				return copied, io.ErrShortWrite
			}
			copied += int64(writeCount)
		}
		if readErr == io.EOF {
			return copied, nil
		}
		if readErr != nil {
			return copied, readErr
		}
	}
}
func (uc *RestoreMongodbBackupUsecase) handleMongoRestoreError(
waitErr error,
stderrOutput []byte,

View File

@@ -27,7 +27,6 @@ import (
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"databasus-backend/internal/util/tools"
)
@@ -134,13 +133,18 @@ func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
}
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
// Stream backup directly from storage
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
return fmt.Errorf("failed to download backup: %w", err)
return fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer cleanupFunc()
defer func() {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
return uc.executeMysqlRestore(ctx, database, mysqlBin, args, myCnfFile, tempBackupFile, backup)
return uc.executeMysqlRestore(ctx, database, mysqlBin, args, myCnfFile, rawReader, backup)
}
func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
@@ -149,7 +153,7 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
mysqlBin string,
args []string,
myCnfFile string,
backupFile string,
backupReader io.ReadCloser,
backup *backups.Backup,
) error {
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
@@ -157,16 +161,10 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
cmd := exec.CommandContext(ctx, mysqlBin, fullArgs...)
uc.logger.Info("Executing MySQL restore command", "command", cmd.String())
backupFileHandle, err := os.Open(backupFile)
if err != nil {
return fmt.Errorf("failed to open backup file: %w", err)
}
defer func() { _ = backupFileHandle.Close() }()
var inputReader io.Reader = backupFileHandle
var inputReader io.Reader = backupReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
decryptReader, err := uc.setupDecryption(backupReader, backup)
if err != nil {
return fmt.Errorf("failed to setup decryption: %w", err)
}
@@ -217,69 +215,6 @@ func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
return nil
}
// downloadBackupToTempFile streams the backup object from storage into a
// newly created temporary file and returns its path together with a cleanup
// function that removes the temporary directory.
//
// The caller must invoke the returned cleanup function once the file is no
// longer needed. On any error the temporary directory is removed before
// returning and the cleanup function is nil.
func (uc *RestoreMysqlBackupUsecase) downloadBackupToTempFile(
	ctx context.Context,
	backup *backups.Backup,
	storage *storages.Storage,
) (string, func(), error) {
	// Make sure the configured temp folder exists before creating files in it.
	err := files_utils.EnsureDirectories([]string{
		config.GetEnv().TempFolder,
	})
	if err != nil {
		return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
	}

	// One directory per restore so concurrent restores never collide.
	tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
	if err != nil {
		return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
	}
	cleanupFunc := func() {
		_ = os.RemoveAll(tempDir)
	}

	tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
	uc.logger.Info(
		"Downloading backup file from storage to temporary file",
		"backupId", backup.ID,
		"tempFile", tempBackupFile,
		"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
	)

	fieldEncryptor := util_encryption.GetFieldEncryptor()
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	tempFile, err := os.Create(tempBackupFile)
	if err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
	}

	_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
	if err != nil {
		_ = tempFile.Close()
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
	}

	// Close explicitly and propagate the error: a failed close can mean the
	// file is incomplete, and a deferred close that merely logs would let
	// this function report success with a truncated backup on disk.
	if err := tempFile.Close(); err != nil {
		cleanupFunc()
		return "", nil, fmt.Errorf("failed to close temporary backup file: %w", err)
	}

	uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
	return tempBackupFile, cleanupFunc, nil
}
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
reader io.Reader,
backup *backups.Backup,
@@ -348,57 +283,6 @@ port=%d
return myCnfFile, nil
}
// copyWithShutdownCheck copies src to dst in 16 MiB chunks while honouring
// both context cancellation and the global application shutdown flag.
// It returns the number of bytes successfully written to dst.
func (uc *RestoreMysqlBackupUsecase) copyWithShutdownCheck(
	ctx context.Context,
	dst io.Writer,
	src io.Reader,
) (int64, error) {
	chunk := make([]byte, 16*1024*1024)
	var copied int64
	for {
		// Bail out between chunks if the restore was cancelled or the
		// application is shutting down.
		select {
		case <-ctx.Done():
			return copied, fmt.Errorf("copy cancelled: %w", ctx.Err())
		default:
		}
		if config.IsShouldShutdown() {
			return copied, fmt.Errorf("copy cancelled due to shutdown")
		}

		readCount, readErr := src.Read(chunk)
		if readCount > 0 {
			writeCount, writeErr := dst.Write(chunk[:readCount])
			// Guard against misbehaving writers reporting impossible counts.
			if writeCount < 0 || readCount < writeCount {
				writeCount = 0
				if writeErr == nil {
					writeErr = fmt.Errorf("invalid write result")
				}
			}
			if writeErr != nil {
				return copied, writeErr
			}
			if readCount != writeCount {
				return copied, io.ErrShortWrite
			}
			copied += int64(writeCount)
		}
		if readErr == io.EOF {
			return copied, nil
		}
		if readErr != nil {
			return copied, readErr
		}
	}
}
func (uc *RestoreMysqlBackupUsecase) handleMysqlRestoreError(
database *databases.Database,
waitErr error,

View File

@@ -82,7 +82,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
)
}
// restoreCustomType restores a backup in custom type (-Fc) - legacy type
// restoreCustomType restores a backup in custom type (-Fc)
func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
originalDB *databases.Database,
pgBin string,
@@ -91,7 +91,248 @@ func (uc *RestorePostgresqlBackupUsecase) restoreCustomType(
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
) error {
uc.logger.Info("Restoring backup in custom type (-Fc)", "backupId", backup.ID)
uc.logger.Info(
"Restoring backup in custom type (-Fc)",
"backupId",
backup.ID,
"cpuCount",
pg.CpuCount,
)
// If excluding extensions, we must use file-based restore (requires TOC file generation)
// Also use file-based restore for parallel jobs (multiple CPUs)
if isExcludeExtensions || pg.CpuCount > 1 {
return uc.restoreViaFile(originalDB, pgBin, backup, storage, pg, isExcludeExtensions)
}
// Single CPU without extension exclusion: stream directly via stdin
return uc.restoreViaStdin(originalDB, pgBin, backup, storage, pg)
}
// restoreViaStdin streams backup via stdin for single CPU restore
// restoreViaStdin streams a custom-format (-Fc) backup from storage directly
// into pg_restore's stdin, avoiding any temporary copy on local disk. If the
// backup is encrypted, a streaming decryption reader is layered on top of the
// storage reader before piping. Used only for the single-CPU path (parallel
// restores need a file — see the caller).
//
// Parameters:
//   - originalDB: database record whose ID keys the password decryption.
//   - pgBin: path of the pg_restore binary to execute.
//   - backup: backup metadata (ID, encryption flag, salt/IV).
//   - storage: storage backend the backup object is fetched from.
//   - pg: connection settings of the target PostgreSQL server.
//
// Returns nil on success; otherwise an error describing the failing stage.
func (uc *RestorePostgresqlBackupUsecase) restoreViaStdin(
	originalDB *databases.Database,
	pgBin string,
	backup *backups.Backup,
	storage *storages.Storage,
	pg *pgtypes.PostgresqlDatabase,
) error {
	uc.logger.Info("Restoring via stdin streaming (CPU=1)", "backupId", backup.ID)

	args := []string{
		"-Fc", // expect custom type
		"--no-password", // never prompt; auth comes from the .pgpass file created below
		"-h", pg.Host,
		"-p", strconv.Itoa(pg.Port),
		"-U", pg.Username,
		"-d", *pg.Database,
		"--verbose",
		"--clean",     // drop existing objects before recreating them
		"--if-exists", // make --clean tolerant of objects that do not exist yet
		"--no-owner",
		"--no-acl",
	}

	// NOTE(review): hard 60-minute ceiling on the whole restore — confirm this
	// is sufficient for the largest expected backups.
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
	defer cancel()

	// Monitor for shutdown and cancel context if needed. The goroutine exits
	// when ctx is done, which the deferred cancel above guarantees.
	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if config.IsShouldShutdown() {
					cancel()
					return
				}
			}
		}
	}()

	// Create temporary .pgpass file for authentication.
	fieldEncryptor := util_encryption.GetFieldEncryptor()
	decryptedPassword, err := fieldEncryptor.Decrypt(originalDB.ID, pg.Password)
	if err != nil {
		return fmt.Errorf("failed to decrypt password: %w", err)
	}

	pgpassFile, err := uc.createTempPgpassFile(pg, decryptedPassword)
	if err != nil {
		return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
	}
	// Remove the whole directory holding the .pgpass file when done.
	defer func() {
		if pgpassFile != "" {
			_ = os.RemoveAll(filepath.Dir(pgpassFile))
		}
	}()

	// Verify .pgpass file was created successfully.
	if pgpassFile == "" {
		return fmt.Errorf("temporary .pgpass file was not created")
	}
	if info, err := os.Stat(pgpassFile); err == nil {
		uc.logger.Info("Temporary .pgpass file created successfully",
			"pgpassFile", pgpassFile,
			"size", info.Size(),
			"mode", info.Mode(),
		)
	} else {
		return fmt.Errorf("failed to verify .pgpass file: %w", err)
	}

	// Get backup stream from storage.
	rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
	if err != nil {
		return fmt.Errorf("failed to get backup file from storage: %w", err)
	}
	defer func() {
		if err := rawReader.Close(); err != nil {
			uc.logger.Error("Failed to close backup reader", "error", err)
		}
	}()

	var backupReader io.Reader = rawReader
	if backup.Encryption == backups_config.BackupEncryptionEncrypted {
		// Validate encryption metadata before attempting decryption.
		if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
			return fmt.Errorf("backup is encrypted but missing encryption metadata")
		}

		// Get master key.
		masterKey, err := uc.secretKeyService.GetSecretKey()
		if err != nil {
			return fmt.Errorf("failed to get master key for decryption: %w", err)
		}

		// Decode salt and IV from base64 (stored as base64 strings on the record).
		salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
		if err != nil {
			return fmt.Errorf("failed to decode encryption salt: %w", err)
		}
		iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
		if err != nil {
			return fmt.Errorf("failed to decode encryption IV: %w", err)
		}

		// Create decryption reader layered over the raw storage stream.
		decryptReader, err := encryption.NewDecryptionReader(
			rawReader,
			masterKey,
			backup.ID,
			salt,
			iv,
		)
		if err != nil {
			return fmt.Errorf("failed to create decryption reader: %w", err)
		}
		backupReader = decryptReader
		uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
	}

	cmd := exec.CommandContext(ctx, pgBin, args...)
	uc.logger.Info("Executing PostgreSQL restore command via stdin", "command", cmd.String())

	// Setup environment variables (PGPASSFILE etc.).
	uc.setupPgRestoreEnvironment(cmd, pgpassFile, pg)

	// Verify executable exists and is accessible before starting.
	if _, err := exec.LookPath(pgBin); err != nil {
		return fmt.Errorf(
			"PostgreSQL executable not found or not accessible: %s - %w",
			pgBin,
			err,
		)
	}

	// Create stdin pipe for explicit data pumping.
	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("stdin pipe: %w", err)
	}

	// Get stderr to capture any error output.
	pgStderr, err := cmd.StderrPipe()
	if err != nil {
		return fmt.Errorf("stderr pipe: %w", err)
	}

	// Capture stderr in a separate goroutine so the pipe never backs up.
	stderrCh := make(chan []byte, 1)
	go func() {
		stderrOutput, _ := io.ReadAll(pgStderr)
		stderrCh <- stderrOutput
	}()

	// Start pg_restore.
	if err = cmd.Start(); err != nil {
		return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
	}

	// Copy backup data to stdin in a separate goroutine with proper error handling.
	copyErrCh := make(chan error, 1)
	go func() {
		_, copyErr := io.Copy(stdinPipe, backupReader)
		// Close stdin pipe to signal EOF to pg_restore - critical for proper termination.
		closeErr := stdinPipe.Close()
		if copyErr != nil {
			copyErrCh <- fmt.Errorf("copy to stdin: %w", copyErr)
		} else if closeErr != nil {
			copyErrCh <- fmt.Errorf("close stdin: %w", closeErr)
		} else {
			copyErrCh <- nil
		}
	}()

	// Wait for the restore to finish, then collect both goroutines' results
	// (each channel is buffered, so neither goroutine can block forever).
	waitErr := cmd.Wait()
	stderrOutput := <-stderrCh
	copyErr := <-copyErrCh

	// Check for shutdown before finalizing.
	if config.IsShouldShutdown() {
		return fmt.Errorf("restore cancelled due to shutdown")
	}

	// Check for copy errors first - these indicate issues with decryption or data reading.
	// NOTE(review): if pg_restore itself fails and exits early, the copier may
	// observe a broken pipe, which would be reported here instead of the
	// pg_restore stderr below — confirm this ordering is intended.
	if copyErr != nil {
		return fmt.Errorf("failed to stream backup data to pg_restore: %w", copyErr)
	}

	if waitErr != nil {
		if config.IsShouldShutdown() {
			return fmt.Errorf("restore cancelled due to shutdown")
		}
		return uc.handlePgRestoreError(originalDB, waitErr, stderrOutput, pgBin, args, pg)
	}

	return nil
}
// restoreViaFile downloads backup and uses parallel jobs for multi-CPU restore
func (uc *RestorePostgresqlBackupUsecase) restoreViaFile(
originalDB *databases.Database,
pgBin string,
backup *backups.Backup,
storage *storages.Storage,
pg *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
) error {
uc.logger.Info(
"Restoring via file with parallel jobs",
"backupId",
backup.ID,
"cpuCount",
pg.CpuCount,
)
// Use parallel jobs based on CPU count
// Cap between 1 and 8 to avoid overwhelming the server

View File

@@ -68,23 +68,31 @@ type TestDataItem struct {
func Test_BackupAndRestorePostgresql_RestoreIsSuccesful(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
name string
version string
port string
cpuCount int
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
{"PostgreSQL 18", "18", env.TestPostgres18Port},
{"PostgreSQL 12 (CPU=1 streamed)", "12", env.TestPostgres12Port, 1},
{"PostgreSQL 12 (CPU=4 directory)", "12", env.TestPostgres12Port, 4},
{"PostgreSQL 13 (CPU=1 streamed)", "13", env.TestPostgres13Port, 1},
{"PostgreSQL 13 (CPU=4 directory)", "13", env.TestPostgres13Port, 4},
{"PostgreSQL 14 (CPU=1 streamed)", "14", env.TestPostgres14Port, 1},
{"PostgreSQL 14 (CPU=4 directory)", "14", env.TestPostgres14Port, 4},
{"PostgreSQL 15 (CPU=1 streamed)", "15", env.TestPostgres15Port, 1},
{"PostgreSQL 15 (CPU=4 directory)", "15", env.TestPostgres15Port, 4},
{"PostgreSQL 16 (CPU=1 streamed)", "16", env.TestPostgres16Port, 1},
{"PostgreSQL 16 (CPU=4 directory)", "16", env.TestPostgres16Port, 4},
{"PostgreSQL 17 (CPU=1 streamed)", "17", env.TestPostgres17Port, 1},
{"PostgreSQL 17 (CPU=4 directory)", "17", env.TestPostgres17Port, 4},
{"PostgreSQL 18 (CPU=1 streamed)", "18", env.TestPostgres18Port, 1},
{"PostgreSQL 18 (CPU=4 directory)", "18", env.TestPostgres18Port, 4},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testBackupRestoreForVersion(t, tc.version, tc.port)
testBackupRestoreForVersion(t, tc.version, tc.port, tc.cpuCount)
})
}
}
@@ -361,7 +369,7 @@ func Test_BackupAndRestorePostgresql_WithReadOnlyUser_RestoreIsSuccessful(t *tes
}
}
func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string, cpuCount int) {
container, err := connectToPostgresContainer(pgVersion, port)
assert.NoError(t, err)
defer func() {
@@ -379,10 +387,11 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
storage := storages.CreateTestStorage(workspace.ID)
database := createDatabaseViaAPI(
database := createDatabaseWithCpuCountViaAPI(
t, router, "Test Database", workspace.ID,
container.Host, container.Port,
container.Username, container.Password, container.Database,
cpuCount,
user.Token,
)
@@ -396,7 +405,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb"
newDBName := fmt.Sprintf("restoreddb_%s_cpu%d", pgVersion, cpuCount)
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
@@ -409,10 +418,11 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
assert.NoError(t, err)
defer newDB.Close()
createRestoreViaAPI(
createRestoreWithCpuCountViaAPI(
t, router, backup.ID,
container.Host, container.Port,
container.Username, container.Password, newDBName,
cpuCount,
user.Token,
)
@@ -1258,6 +1268,27 @@ func createDatabaseViaAPI(
password string,
database string,
token string,
) *databases.Database {
return createDatabaseWithCpuCountViaAPI(
t, router, name, workspaceID,
host, port, username, password, database,
1,
token,
)
}
func createDatabaseWithCpuCountViaAPI(
t *testing.T,
router *gin.Engine,
name string,
workspaceID uuid.UUID,
host string,
port int,
username string,
password string,
database string,
cpuCount int,
token string,
) *databases.Database {
request := databases.Database{
Name: name,
@@ -1269,7 +1300,7 @@ func createDatabaseViaAPI(
Username: username,
Password: password,
Database: &database,
CpuCount: 1,
CpuCount: cpuCount,
},
}
@@ -1354,7 +1385,7 @@ func createRestoreViaAPI(
database string,
token string,
) {
createRestoreWithOptionsViaAPI(
createRestoreWithCpuCountViaAPI(
t,
router,
backupID,
@@ -1363,11 +1394,44 @@ func createRestoreViaAPI(
username,
password,
database,
false,
1,
token,
)
}
// createRestoreWithCpuCountViaAPI triggers a restore of the given backup
// through the REST API, targeting the PostgreSQL server described by
// host/port/credentials and restoring into `database`. cpuCount is forwarded
// in the request so the backend can pick its restore strategy. The request
// is asserted to return HTTP 200.
func createRestoreWithCpuCountViaAPI(
	t *testing.T,
	router *gin.Engine,
	backupID uuid.UUID,
	host string,
	port int,
	username string,
	password string,
	database string,
	cpuCount int,
	token string,
) {
	// Describe the database the backup should be restored into.
	target := &pgtypes.PostgresqlDatabase{
		Host:     host,
		Port:     port,
		Username: username,
		Password: password,
		Database: &database,
		CpuCount: cpuCount,
	}

	endpoint := fmt.Sprintf("/api/v1/restores/%s/restore", backupID.String())
	test_utils.MakePostRequest(
		t,
		router,
		endpoint,
		"Bearer "+token,
		restores.RestoreBackupRequest{PostgresqlDatabase: target},
		http.StatusOK,
	)
}
func createRestoreWithOptionsViaAPI(
t *testing.T,
router *gin.Engine,

View File

@@ -176,6 +176,13 @@ func getPostgresqlBasePath(
postgresesInstallDir string,
) string {
if envMode == env_utils.EnvModeDevelopment {
// On Windows, PostgreSQL 12 and 13 have issues with piping over restore
if runtime.GOOS == "windows" {
if version == PostgresqlVersion12 || version == PostgresqlVersion13 {
version = PostgresqlVersion14
}
}
return filepath.Join(
postgresesInstallDir,
fmt.Sprintf("postgresql-%s", string(version)),

View File

@@ -50,13 +50,13 @@ const initializeDatabaseTypeData = (db: Database): Database => {
switch (db.type) {
case DatabaseType.POSTGRES:
return { ...base, postgresql: db.postgresql ?? ({ cpuCount: 4 } as PostgresqlDatabase) };
return { ...base, postgresql: db.postgresql ?? ({ cpuCount: 1 } as PostgresqlDatabase) };
case DatabaseType.MYSQL:
return { ...base, mysql: db.mysql ?? ({} as MysqlDatabase) };
case DatabaseType.MARIADB:
return { ...base, mariadb: db.mariadb ?? ({} as MariadbDatabase) };
case DatabaseType.MONGODB:
return { ...base, mongodb: db.mongodb ?? ({ cpuCount: 4 } as MongodbDatabase) };
return { ...base, mongodb: db.mongodb ?? ({ cpuCount: 1 } as MongodbDatabase) };
default:
return db;
}

View File

@@ -78,7 +78,7 @@ export const EditMongoDbSpecificDataComponent = ({
database: result.database,
authDatabase: result.authDatabase,
isHttps: result.useTls,
cpuCount: 4,
cpuCount: 1,
},
};

View File

@@ -82,7 +82,7 @@ export const EditPostgreSqlSpecificDataComponent = ({
password: result.password,
database: result.database,
isHttps: result.isHttps,
cpuCount: 4,
cpuCount: 1,
},
};
@@ -356,34 +356,36 @@ export const EditPostgreSqlSpecificDataComponent = ({
/>
</div>
<div className="mb-5 flex w-full items-center">
<div className="min-w-[150px]">CPU count</div>
<div className="flex items-center">
<InputNumber
min={1}
max={128}
value={editingDatabase.postgresql?.cpuCount}
onChange={(value) => {
if (!editingDatabase.postgresql) return;
{isRestoreMode && (
<div className="mb-5 flex w-full items-center">
<div className="min-w-[150px]">CPU count</div>
<div className="flex items-center">
<InputNumber
min={1}
max={128}
value={editingDatabase.postgresql?.cpuCount}
onChange={(value) => {
if (!editingDatabase.postgresql) return;
setEditingDatabase({
...editingDatabase,
postgresql: { ...editingDatabase.postgresql, cpuCount: value || 1 },
});
setIsConnectionTested(false);
}}
size="small"
className="max-w-[75px] grow"
/>
setEditingDatabase({
...editingDatabase,
postgresql: { ...editingDatabase.postgresql, cpuCount: value || 1 },
});
setIsConnectionTested(false);
}}
size="small"
className="max-w-[75px] grow"
/>
<Tooltip
className="cursor-pointer"
title="Number of CPU cores to use for backup and restore operations. Higher values may speed up operations but use more resources."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
<Tooltip
className="cursor-pointer"
title="Number of CPU cores to use for backup and restore operations. Higher values may speed up operations but use more resources."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
</div>
</div>
)}
<div className="mt-4 mb-1 flex items-center">
<div

View File

@@ -54,11 +54,6 @@ export const ShowPostgreSqlSpecificDataComponent = ({ database }: Props) => {
<div>{database.postgresql?.isHttps ? 'Yes' : 'No'}</div>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">CPU count</div>
<div>{database.postgresql?.cpuCount}</div>
</div>
{!!database.postgresql?.includeSchemas?.length && (
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Include schemas</div>