mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 08:41:58 +02:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
878fad5747 | ||
|
|
6ff3096695 |
@@ -5,6 +5,7 @@ import (
|
||||
"postgresus-backend/internal/config"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/period"
|
||||
"time"
|
||||
)
|
||||
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_models "postgresus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
@@ -700,7 +701,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
|
||||
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"time"
|
||||
)
|
||||
@@ -25,6 +26,7 @@ var backupService = &BackupService{
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -31,6 +32,7 @@ type BackupService struct {
|
||||
notificationSender NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
|
||||
@@ -284,7 +286,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(backup.ID); deleteErr != nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
@@ -545,7 +547,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
@@ -599,7 +601,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
|
||||
return nil, fmt.Errorf("failed to get storage: %w", err)
|
||||
}
|
||||
|
||||
fileReader, err := storage.GetFile(backup.ID)
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup file: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -56,11 +57,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil, // auditLogService
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
@@ -103,11 +105,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil, // auditLogService
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
@@ -127,11 +130,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil, // auditLogService
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
|
||||
@@ -15,12 +15,13 @@ import (
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -40,8 +41,9 @@ const (
|
||||
)
|
||||
|
||||
type CreatePostgresqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
logger *slog.Logger
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database
|
||||
@@ -166,7 +168,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
@@ -440,7 +442,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *encryption.EncryptionWriter, BackupMetadata, error) {
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
|
||||
metadata := BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
@@ -449,12 +451,12 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
return storageWriter, nil, metadata, nil
|
||||
}
|
||||
|
||||
salt, err := encryption.GenerateSalt()
|
||||
salt, err := backup_encryption.GenerateSalt()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := encryption.GenerateNonce()
|
||||
nonce, err := backup_encryption.GenerateNonce()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
@@ -464,7 +466,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
encWriter, err := encryption.NewEncryptionWriter(
|
||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
||||
storageWriter,
|
||||
masterKey,
|
||||
backupID,
|
||||
@@ -486,7 +488,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
||||
encryptionWriter *encryption.EncryptionWriter,
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
saveErrCh chan error,
|
||||
) {
|
||||
@@ -510,7 +512,7 @@ func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
|
||||
encryptionWriter *encryption.EncryptionWriter,
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
) error {
|
||||
encryptionCloseErrCh := make(chan error, 1)
|
||||
|
||||
@@ -2,12 +2,14 @@ package usecases_postgresql
|
||||
|
||||
import (
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
@@ -769,6 +770,71 @@ func createTestDatabaseViaAPI(
|
||||
return &database
|
||||
}
|
||||
|
||||
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
testDbName := "test_db"
|
||||
plainPassword := "my-super-secret-password-123"
|
||||
request := Database{
|
||||
Name: "Test Database",
|
||||
WorkspaceID: &workspace.ID,
|
||||
Type: DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: plainPassword,
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
var createdDatabase Database
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusCreated,
|
||||
&createdDatabase,
|
||||
)
|
||||
|
||||
repository := &DatabaseRepository{}
|
||||
databaseFromDB, err := repository.FindByID(createdDatabase.ID)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, databaseFromDB)
|
||||
assert.NotNil(t, databaseFromDB.Postgresql)
|
||||
|
||||
assert.True(
|
||||
t,
|
||||
strings.HasPrefix(databaseFromDB.Postgresql.Password, "enc:"),
|
||||
"Password should be encrypted in database with 'enc:' prefix, got: %s",
|
||||
databaseFromDB.Postgresql.Password,
|
||||
)
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decryptedPassword, err := encryptor.Decrypt(
|
||||
databaseFromDB.ID,
|
||||
databaseFromDB.Postgresql.Password,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plainPassword, decryptedPassword,
|
||||
"Decrypted password should match original plaintext password")
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+createdDatabase.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -815,7 +881,15 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "original-password-secret", database.Postgresql.Password)
|
||||
// Verify password is encrypted
|
||||
assert.True(t, strings.HasPrefix(database.Postgresql.Password, "enc:"),
|
||||
"Password should be encrypted in database")
|
||||
|
||||
// Verify it can be decrypted back to original
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Postgresql.Password)
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
"regexp"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -59,11 +59,15 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
|
||||
func (p *PostgresqlDatabase) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return testSingleDatabaseConnection(logger, ctx, p)
|
||||
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) HideSensitiveData() {
|
||||
@@ -87,19 +91,42 @@ func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
databaseID uuid.UUID,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if p.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Password = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
|
||||
func testSingleDatabaseConnection(
|
||||
logger *slog.Logger,
|
||||
ctx context.Context,
|
||||
postgresDb *PostgresqlDatabase,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
// For single database backup, we need to connect to the specific database
|
||||
if postgresDb.Database == nil || *postgresDb.Database == "" {
|
||||
return errors.New("database name is required for single database backup (pg_dump)")
|
||||
}
|
||||
|
||||
// Decrypt password if needed
|
||||
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Build connection string for the specific database
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
|
||||
|
||||
// Test connection
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
@@ -182,7 +209,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
|
||||
}
|
||||
|
||||
// buildConnectionStringForDB builds connection string for specific database
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
|
||||
sslMode := "disable"
|
||||
if p.IsHttps {
|
||||
sslMode = "require"
|
||||
@@ -192,106 +219,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
p.Host,
|
||||
p.Port,
|
||||
p.Username,
|
||||
p.Password,
|
||||
password,
|
||||
dbName,
|
||||
sslMode,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
|
||||
if len(extensions) == 0 {
|
||||
return nil
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, error) {
|
||||
if encryptor == nil {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
if p.Database == nil || *p.Database == "" {
|
||||
return errors.New("database name is required for installing extensions")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Build connection string for the specific database
|
||||
connStr := buildConnectionStringForDB(p, *p.Database)
|
||||
|
||||
// Connect to database
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
fmt.Println("failed to close connection: %w", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check which extensions are already installed
|
||||
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check installed extensions: %w", err)
|
||||
}
|
||||
|
||||
// Install missing extensions
|
||||
for _, extension := range extensions {
|
||||
if contains(installedExtensions, string(extension)) {
|
||||
continue // Extension already installed
|
||||
}
|
||||
|
||||
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
|
||||
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInstalledExtensions queries the database for currently installed extensions
|
||||
func (p *PostgresqlDatabase) getInstalledExtensions(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
) ([]string, error) {
|
||||
query := "SELECT extname FROM pg_extension"
|
||||
|
||||
rows, err := conn.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var extensions []string
|
||||
for rows.Next() {
|
||||
var extname string
|
||||
|
||||
if err := rows.Scan(&extname); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan extension name: %w", err)
|
||||
}
|
||||
|
||||
extensions = append(extensions, extname)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
|
||||
}
|
||||
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
// installExtension installs a single PostgreSQL extension
|
||||
func (p *PostgresqlDatabase) installExtension(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
extensionName string,
|
||||
) error {
|
||||
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
|
||||
|
||||
_, err := conn.Exec(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// contains checks if a string slice contains a specific string
|
||||
func contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -19,6 +20,7 @@ var databaseService = &DatabaseService{
|
||||
[]DatabaseCopyListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
var databaseController = &DatabaseController{
|
||||
|
||||
@@ -2,6 +2,7 @@ package databases
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -11,7 +12,11 @@ type DatabaseValidator interface {
|
||||
}
|
||||
|
||||
type DatabaseConnector interface {
|
||||
TestConnection(logger *slog.Logger) error
|
||||
TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error
|
||||
|
||||
HideSensitiveData()
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -56,14 +57,24 @@ func (d *Database) ValidateUpdate(old, new Database) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) TestConnection(logger *slog.Logger) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger)
|
||||
func (d *Database) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
|
||||
}
|
||||
|
||||
func (d *Database) HideSensitiveData() {
|
||||
d.getSpecificDatabase().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Update(incoming *Database) {
|
||||
d.Name = incoming.Name
|
||||
d.Type = incoming.Type
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -26,6 +27,7 @@ type DatabaseService struct {
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *DatabaseService) AddDbCreationListener(
|
||||
@@ -65,6 +67,10 @@ func (s *DatabaseService) CreateDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
database, err = s.dbRepository.Save(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -118,6 +124,10 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.dbRepository.Save(existingDatabase)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -250,7 +260,7 @@ func (s *DatabaseService) TestDatabaseConnection(
|
||||
return errors.New("insufficient permissions to test connection for this database")
|
||||
}
|
||||
|
||||
err = database.TestConnection(s.logger)
|
||||
err = database.TestConnection(s.logger, s.fieldEncryptor)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
database.LastBackupErrorMessage = &lastSaveError
|
||||
@@ -294,7 +304,7 @@ func (s *DatabaseService) TestDatabaseConnectionDirect(
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
return usingDatabase.TestConnection(s.logger)
|
||||
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetDatabaseByID(
|
||||
|
||||
@@ -453,70 +453,6 @@ func Test_CrossWorkspaceSecurity_CannotAccessNotifierFromAnotherWorkspace(t *tes
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
GetNotifierController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
|
||||
}
|
||||
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Notifier " + uuid.New().String(),
|
||||
NotifierType: NotifierTypeWebhook,
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
|
||||
env := config.GetEnv()
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Telegram Notifier " + uuid.New().String(),
|
||||
NotifierType: NotifierTypeTelegram,
|
||||
TelegramNotifier: &telegram_notifier.TelegramNotifier{
|
||||
BotToken: env.TestTelegramBotToken,
|
||||
TargetChatID: env.TestTelegramChatID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
|
||||
assert.Equal(t, expected.Name, actual.Name)
|
||||
assert.Equal(t, expected.NotifierType, actual.NotifierType)
|
||||
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
|
||||
}
|
||||
|
||||
func deleteNotifier(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
notifierID, workspaceID uuid.UUID,
|
||||
token string,
|
||||
) {
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -553,7 +489,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "original-bot-token-12345", notifier.TelegramNotifier.BotToken)
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.TelegramNotifier.BotToken),
|
||||
"BotToken should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
|
||||
assert.Equal(t, "original-bot-token-12345", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "", notifier.TelegramNotifier.BotToken)
|
||||
@@ -592,7 +534,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "original-password-secret", notifier.EmailNotifier.SMTPPassword)
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.EmailNotifier.SMTPPassword),
|
||||
"SMTPPassword should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "", notifier.EmailNotifier.SMTPPassword)
|
||||
@@ -625,7 +573,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "xoxb-original-slack-token", notifier.SlackNotifier.BotToken)
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.SlackNotifier.BotToken),
|
||||
"BotToken should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
|
||||
assert.Equal(t, "xoxb-original-slack-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "", notifier.SlackNotifier.BotToken)
|
||||
@@ -656,11 +610,17 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(
|
||||
assert.True(
|
||||
t,
|
||||
"https://discord.com/api/webhooks/123/original-token",
|
||||
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
|
||||
"WebhookURL should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.DiscordNotifier.ChannelWebhookURL,
|
||||
)
|
||||
assert.Equal(t, "https://discord.com/api/webhooks/123/original-token", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.Equal(t, "", notifier.DiscordNotifier.ChannelWebhookURL)
|
||||
@@ -691,10 +651,16 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.TeamsNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted in DB",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
|
||||
assert.Equal(
|
||||
t,
|
||||
"https://outlook.office.com/webhook/original-token",
|
||||
notifier.TeamsNotifier.WebhookURL,
|
||||
decrypted,
|
||||
)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
|
||||
@@ -813,3 +779,263 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
createNotifier func(workspaceID uuid.UUID) *Notifier
|
||||
verifySensitiveEncryption func(t *testing.T, notifier *Notifier)
|
||||
}{
|
||||
{
|
||||
name: "Telegram Notifier - BotToken encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Telegram",
|
||||
NotifierType: NotifierTypeTelegram,
|
||||
TelegramNotifier: &telegram_notifier.TelegramNotifier{
|
||||
BotToken: "plain-telegram-token-123",
|
||||
TargetChatID: "123456789",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.TelegramNotifier.BotToken),
|
||||
"BotToken should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
|
||||
assert.Equal(t, "plain-telegram-token-123", decrypted)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Email Notifier - SMTPPassword encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Email",
|
||||
NotifierType: NotifierTypeEmail,
|
||||
EmailNotifier: &email_notifier.EmailNotifier{
|
||||
TargetEmail: "test@example.com",
|
||||
SMTPHost: "smtp.example.com",
|
||||
SMTPPort: 587,
|
||||
SMTPUser: "user@example.com",
|
||||
SMTPPassword: "plain-smtp-password-456",
|
||||
From: "noreply@example.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.EmailNotifier.SMTPPassword),
|
||||
"SMTPPassword should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
|
||||
assert.Equal(t, "plain-smtp-password-456", decrypted)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Slack Notifier - BotToken encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Slack",
|
||||
NotifierType: NotifierTypeSlack,
|
||||
SlackNotifier: &slack_notifier.SlackNotifier{
|
||||
BotToken: "plain-slack-token-789",
|
||||
TargetChatID: "C0123456789",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.SlackNotifier.BotToken),
|
||||
"BotToken should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
|
||||
assert.Equal(t, "plain-slack-token-789", decrypted)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Discord Notifier - WebhookURL encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Discord",
|
||||
NotifierType: NotifierTypeDiscord,
|
||||
DiscordNotifier: &discord_notifier.DiscordNotifier{
|
||||
ChannelWebhookURL: "https://discord.com/api/webhooks/123/abc",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(
|
||||
t,
|
||||
notifier.ID,
|
||||
notifier.DiscordNotifier.ChannelWebhookURL,
|
||||
)
|
||||
assert.Equal(t, "https://discord.com/api/webhooks/123/abc", decrypted)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Teams Notifier - WebhookURL encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Teams",
|
||||
NotifierType: NotifierTypeTeams,
|
||||
TeamsNotifier: &teams_notifier.TeamsNotifier{
|
||||
WebhookURL: "https://outlook.office.com/webhook/test123",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.TeamsNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://outlook.office.com/webhook/test123", decrypted)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Webhook Notifier - WebhookURL encrypted",
|
||||
createNotifier: func(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Webhook",
|
||||
NotifierType: NotifierTypeWebhook,
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.example.com/test456",
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
|
||||
assert.True(
|
||||
t,
|
||||
isEncrypted(notifier.WebhookNotifier.WebhookURL),
|
||||
"WebhookURL should be encrypted",
|
||||
)
|
||||
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
|
||||
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
// Create notifier via API (plaintext credentials)
|
||||
var createdNotifier Notifier
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/notifiers",
|
||||
"Bearer "+owner.Token,
|
||||
tc.createNotifier(workspace.ID),
|
||||
http.StatusOK,
|
||||
&createdNotifier,
|
||||
)
|
||||
|
||||
// Read from DB directly (bypass service layer)
|
||||
repository := &NotifierRepository{}
|
||||
notifierFromDB, err := repository.FindByID(createdNotifier.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify encryption
|
||||
tc.verifySensitiveEncryption(t, notifierFromDB)
|
||||
|
||||
// Cleanup
|
||||
deleteNotifier(t, router, createdNotifier.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
GetNotifierController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
|
||||
}
|
||||
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Notifier " + uuid.New().String(),
|
||||
NotifierType: NotifierTypeWebhook,
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
|
||||
WebhookMethod: webhook_notifier.WebhookMethodPOST,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
|
||||
env := config.GetEnv()
|
||||
return &Notifier{
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "Test Telegram Notifier " + uuid.New().String(),
|
||||
NotifierType: NotifierTypeTelegram,
|
||||
TelegramNotifier: &telegram_notifier.TelegramNotifier{
|
||||
BotToken: env.TestTelegramBotToken,
|
||||
TargetChatID: env.TestTelegramChatID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
|
||||
assert.Equal(t, expected.Name, actual.Name)
|
||||
assert.Equal(t, expected.NotifierType, actual.NotifierType)
|
||||
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
|
||||
}
|
||||
|
||||
func deleteNotifier(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
notifierID, workspaceID uuid.UUID,
|
||||
token string,
|
||||
) {
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
func isEncrypted(value string) bool {
|
||||
return len(value) > 4 && value[:4] == "enc:"
|
||||
}
|
||||
|
||||
func decryptField(t *testing.T, notifierID uuid.UUID, encryptedValue string) string {
|
||||
encryptor := GetNotifierService().fieldEncryptor
|
||||
decrypted, err := encryptor.Decrypt(notifierID, encryptedValue)
|
||||
assert.NoError(t, err)
|
||||
return decrypted
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package notifiers
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -12,6 +13,7 @@ var notifierService = &NotifierService{
|
||||
logger.GetLogger(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var notifierController = &NotifierController{
|
||||
notifierService,
|
||||
|
||||
@@ -1,11 +1,21 @@
|
||||
package notifiers
|
||||
|
||||
import "log/slog"
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
Send(logger *slog.Logger, heading string, message string) error
|
||||
Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error
|
||||
|
||||
Validate() error
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
HideSensitiveData()
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
|
||||
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
|
||||
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
|
||||
return "notifiers"
|
||||
}
|
||||
|
||||
func (n *Notifier) Validate() error {
|
||||
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
|
||||
return n.getSpecificNotifier().Validate()
|
||||
return n.getSpecificNotifier().Validate(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
err := n.getSpecificNotifier().Send(logger, heading, message)
|
||||
func (n *Notifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
|
||||
|
||||
if err != nil {
|
||||
lastSendError := err.Error()
|
||||
@@ -58,6 +64,10 @@ func (n *Notifier) HideSensitiveData() {
|
||||
n.getSpecificNotifier().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Update(incoming *Notifier) {
|
||||
n.Name = incoming.Name
|
||||
n.NotifierType = incoming.NotifierType
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
|
||||
return "discord_notifiers"
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Validate() error {
|
||||
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (d *DiscordNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal Discord payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
|
||||
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -81,3 +92,14 @@ func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
|
||||
d.ChannelWebhookURL = incoming.ChannelWebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
d.ChannelWebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -34,7 +35,7 @@ func (e *EmailNotifier) TableName() string {
|
||||
return "email_notifiers"
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Validate() error {
|
||||
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if e.TargetEmail == "" {
|
||||
return errors.New("target email is required")
|
||||
}
|
||||
@@ -55,7 +56,22 @@ func (e *EmailNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (e *EmailNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
// Decrypt SMTP password if provided
|
||||
var smtpPassword string
|
||||
if e.SMTPPassword != "" {
|
||||
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
|
||||
}
|
||||
smtpPassword = decrypted
|
||||
}
|
||||
|
||||
// Compose email
|
||||
from := e.From
|
||||
if from == "" {
|
||||
@@ -85,7 +101,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
timeout := DefaultTimeout
|
||||
|
||||
// Determine if authentication is required
|
||||
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
|
||||
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
|
||||
|
||||
// Handle different port scenarios
|
||||
if e.SMTPPort == ImplicitTLSPort {
|
||||
@@ -116,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Set up authentication only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -179,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Authenticate only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -229,3 +245,14 @@ func (e *EmailNotifier) Update(incoming *EmailNotifier) {
|
||||
e.SMTPPassword = incoming.SMTPPassword
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if e.SMTPPassword != "" {
|
||||
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
|
||||
}
|
||||
e.SMTPPassword = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,7 +24,7 @@ type SlackNotifier struct {
|
||||
|
||||
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
|
||||
|
||||
func (s *SlackNotifier) Validate() error {
|
||||
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
|
||||
func (s *SlackNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
full := fmt.Sprintf("*%s*", heading)
|
||||
|
||||
if message != "" {
|
||||
@@ -80,7 +90,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", "Bearer "+s.BotToken)
|
||||
req.Header.Set("Authorization", "Bearer "+botToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -144,3 +154,14 @@ func (s *SlackNotifier) Update(incoming *SlackNotifier) {
|
||||
s.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
s.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
|
||||
return "teams_notifiers"
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Validate() error {
|
||||
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL == "" {
|
||||
return errors.New("webhook_url is required")
|
||||
}
|
||||
u, err := url.Parse(n.WebhookURL)
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(webhookURL)
|
||||
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return errors.New("invalid webhook_url")
|
||||
}
|
||||
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
|
||||
}
|
||||
|
||||
type cardAttachment struct {
|
||||
ContentType string `json:"contentType"`
|
||||
Content interface{} `json:"content"`
|
||||
ContentType string `json:"contentType"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
@@ -43,11 +50,20 @@ type payload struct {
|
||||
Attachments []cardAttachment `json:"attachments,omitempty"`
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
|
||||
if err := n.Validate(); err != nil {
|
||||
func (n *TeamsNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
if err := n.Validate(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
card := map[string]any{
|
||||
"type": "AdaptiveCard",
|
||||
"version": "1.4",
|
||||
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(p)
|
||||
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
|
||||
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -104,3 +120,14 @@ func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
|
||||
n.WebhookURL = incoming.WebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
n.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
|
||||
return "telegram_notifiers"
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Validate() error {
|
||||
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *TelegramNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("chat_id", t.TargetChatID)
|
||||
@@ -93,3 +104,14 @@ func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
|
||||
t.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
t.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
|
||||
return "webhook_notifiers"
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Validate() error {
|
||||
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *WebhookNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
|
||||
t.WebhookURL,
|
||||
webhookURL,
|
||||
url.QueryEscape(heading),
|
||||
url.QueryEscape(message),
|
||||
)
|
||||
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
|
||||
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send POST webhook: %w", err)
|
||||
}
|
||||
@@ -110,3 +121,14 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
t.WebhookURL = incoming.WebhookURL
|
||||
t.WebhookMethod = incoming.WebhookMethod
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -17,6 +18,7 @@ type NotifierService struct {
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *NotifierService) SaveNotifier(
|
||||
@@ -46,7 +48,11 @@ func (s *NotifierService) SaveNotifier(
|
||||
|
||||
existingNotifier.Update(notifier)
|
||||
|
||||
if err := existingNotifier.Validate(); err != nil {
|
||||
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -63,7 +69,11 @@ func (s *NotifierService) SaveNotifier(
|
||||
} else {
|
||||
notifier.WorkspaceID = workspaceID
|
||||
|
||||
if err := notifier.Validate(); err != nil {
|
||||
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := notifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -175,7 +185,7 @@ func (s *NotifierService) SendTestNotification(
|
||||
return errors.New("insufficient permissions to test notifier in this workspace")
|
||||
}
|
||||
|
||||
err = notifier.Send(s.logger, "Test message", "This is a test message")
|
||||
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -205,16 +215,24 @@ func (s *NotifierService) SendTestNotificationToNotifier(
|
||||
|
||||
existingNotifier.Update(notifier)
|
||||
|
||||
if err := existingNotifier.Validate(); err != nil {
|
||||
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingNotifier = existingNotifier
|
||||
} else {
|
||||
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingNotifier = notifier
|
||||
}
|
||||
|
||||
return usingNotifier.Send(s.logger, "Test message", "This is a test message")
|
||||
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
|
||||
}
|
||||
|
||||
func (s *NotifierService) SendNotification(
|
||||
@@ -233,7 +251,7 @@ func (s *NotifierService) SendNotification(
|
||||
return
|
||||
}
|
||||
|
||||
err = notifiedFromDb.Send(s.logger, title, message)
|
||||
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
notifiedFromDb.LastSendError = &errMsg
|
||||
|
||||
@@ -29,6 +29,7 @@ import (
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_models "postgresus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
@@ -309,6 +310,7 @@ func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *backups.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
@@ -338,7 +340,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
|
||||
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
@@ -209,7 +210,8 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
"encrypted",
|
||||
backup.Encryption == backups_config.BackupEncryptionEncrypted,
|
||||
)
|
||||
rawReader, err := storage.GetFile(backup.ID)
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
|
||||
@@ -3,10 +3,14 @@ package storages
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
|
||||
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
@@ -14,6 +18,7 @@ import (
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -438,6 +443,386 @@ func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *test
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}
|
||||
|
||||
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storageType StorageType
|
||||
createStorage func(workspaceID uuid.UUID) *Storage
|
||||
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
|
||||
verifySensitiveData func(t *testing.T, storage *Storage)
|
||||
verifyHiddenData func(t *testing.T, storage *Storage)
|
||||
}{
|
||||
{
|
||||
name: "S3 Storage",
|
||||
storageType: StorageTypeS3,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Test S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "test-bucket",
|
||||
S3Region: "us-east-1",
|
||||
S3AccessKey: "original-access-key",
|
||||
S3SecretKey: "original-secret-key",
|
||||
S3Endpoint: "https://s3.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Updated S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "updated-bucket",
|
||||
S3Region: "us-west-2",
|
||||
S3AccessKey: "",
|
||||
S3SecretKey: "",
|
||||
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
|
||||
"S3AccessKey should be encrypted with 'enc:' prefix")
|
||||
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
|
||||
"S3SecretKey should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-access-key", accessKey)
|
||||
|
||||
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-secret-key", secretKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
|
||||
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Local Storage",
|
||||
storageType: StorageTypeLocal,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Test Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Updated Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NAS Storage",
|
||||
storageType: StorageTypeNAS,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeNAS,
|
||||
Name: "Test NAS Storage",
|
||||
NASStorage: &nas_storage.NASStorage{
|
||||
Host: "nas.example.com",
|
||||
Port: 445,
|
||||
Share: "backups",
|
||||
Username: "testuser",
|
||||
Password: "original-password",
|
||||
UseSSL: false,
|
||||
Domain: "WORKGROUP",
|
||||
Path: "/test",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeNAS,
|
||||
Name: "Updated NAS Storage",
|
||||
NASStorage: &nas_storage.NASStorage{
|
||||
Host: "nas2.example.com",
|
||||
Port: 445,
|
||||
Share: "backups2",
|
||||
Username: "testuser2",
|
||||
Password: "",
|
||||
UseSSL: true,
|
||||
Domain: "WORKGROUP2",
|
||||
Path: "/test2",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
|
||||
"Password should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password", password)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.NASStorage.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure Blob Storage (Connection String)",
|
||||
storageType: StorageTypeAzureBlob,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Test Azure Blob Storage",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
|
||||
ConnectionString: "original-connection-string",
|
||||
ContainerName: "test-container",
|
||||
Endpoint: "",
|
||||
Prefix: "backups/",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Updated Azure Blob Storage",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
|
||||
ConnectionString: "",
|
||||
ContainerName: "updated-container",
|
||||
Endpoint: "https://custom.blob.core.windows.net",
|
||||
Prefix: "backups2/",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
|
||||
"ConnectionString should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
connectionString, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.AzureBlobStorage.ConnectionString,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-connection-string", connectionString)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure Blob Storage (Account Key)",
|
||||
storageType: StorageTypeAzureBlob,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Test Azure Blob with Account Key",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
|
||||
AccountName: "testaccount",
|
||||
AccountKey: "original-account-key",
|
||||
ContainerName: "test-container",
|
||||
Endpoint: "",
|
||||
Prefix: "backups/",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Updated Azure Blob with Account Key",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
|
||||
AccountName: "updatedaccount",
|
||||
AccountKey: "",
|
||||
ContainerName: "updated-container",
|
||||
Endpoint: "https://custom.blob.core.windows.net",
|
||||
Prefix: "backups2/",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
|
||||
"AccountKey should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
accountKey, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.AzureBlobStorage.AccountKey,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-account-key", accountKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Google Drive Storage",
|
||||
storageType: StorageTypeGoogleDrive,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeGoogleDrive,
|
||||
Name: "Test Google Drive Storage",
|
||||
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
|
||||
ClientID: "original-client-id",
|
||||
ClientSecret: "original-client-secret",
|
||||
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeGoogleDrive,
|
||||
Name: "Updated Google Drive Storage",
|
||||
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
|
||||
ClientID: "updated-client-id",
|
||||
ClientSecret: "",
|
||||
TokenJSON: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
|
||||
"ClientSecret should be encrypted with 'enc:' prefix")
|
||||
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
|
||||
"TokenJSON should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
clientSecret, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.GoogleDriveStorage.ClientSecret,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-client-secret", clientSecret)
|
||||
|
||||
tokenJSON, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.GoogleDriveStorage.TokenJSON,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
|
||||
tokenJSON,
|
||||
)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
|
||||
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
// Phase 1: Create storage with sensitive data
|
||||
initialStorage := tc.createStorage(workspace.ID)
|
||||
var createdStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*initialStorage,
|
||||
http.StatusOK,
|
||||
&createdStorage,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, createdStorage.ID)
|
||||
assert.Equal(t, initialStorage.Name, createdStorage.Name)
|
||||
|
||||
// Phase 2: Verify sensitive data is encrypted in repository after creation
|
||||
repository := &StorageRepository{}
|
||||
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
|
||||
assert.NoError(t, err)
|
||||
tc.verifySensitiveData(t, storageFromDBAfterCreate)
|
||||
|
||||
// Phase 3: Read via service - sensitive data should be hidden
|
||||
var retrievedStorage Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&retrievedStorage,
|
||||
)
|
||||
|
||||
tc.verifyHiddenData(t, &retrievedStorage)
|
||||
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
|
||||
|
||||
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
|
||||
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
|
||||
var updateResponse Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*updatedStorage,
|
||||
http.StatusOK,
|
||||
&updateResponse,
|
||||
)
|
||||
|
||||
// Verify non-sensitive fields were updated
|
||||
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
|
||||
|
||||
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
|
||||
storageFromDB, err := repository.FindByID(createdStorage.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify original sensitive data is still present in DB
|
||||
tc.verifySensitiveData(t, storageFromDB)
|
||||
|
||||
// Verify non-sensitive fields were updated in DB
|
||||
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
|
||||
|
||||
// Additional verification: Check via GET that data is still hidden
|
||||
var finalRetrieved Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&finalRetrieved,
|
||||
)
|
||||
tc.verifyHiddenData(t, &finalRetrieved)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
@@ -485,158 +870,3 @@ func deleteStorage(
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storageType StorageType
|
||||
createStorage func(workspaceID uuid.UUID) *Storage
|
||||
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
|
||||
verifySensitiveData func(t *testing.T, storage *Storage)
|
||||
verifyHiddenData func(t *testing.T, storage *Storage)
|
||||
}{
|
||||
{
|
||||
name: "S3 Storage",
|
||||
storageType: StorageTypeS3,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Test S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "test-bucket",
|
||||
S3Region: "us-east-1",
|
||||
S3AccessKey: "original-access-key",
|
||||
S3SecretKey: "original-secret-key",
|
||||
S3Endpoint: "https://s3.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Updated S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "updated-bucket",
|
||||
S3Region: "us-west-2",
|
||||
S3AccessKey: "",
|
||||
S3SecretKey: "",
|
||||
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "original-access-key", storage.S3Storage.S3AccessKey)
|
||||
assert.Equal(t, "original-secret-key", storage.S3Storage.S3SecretKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
|
||||
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Local Storage",
|
||||
storageType: StorageTypeLocal,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Test Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Updated Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
// Phase 1: Create storage with sensitive data
|
||||
initialStorage := tc.createStorage(workspace.ID)
|
||||
var createdStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*initialStorage,
|
||||
http.StatusOK,
|
||||
&createdStorage,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, createdStorage.ID)
|
||||
assert.Equal(t, initialStorage.Name, createdStorage.Name)
|
||||
|
||||
// Phase 2: Read via service - sensitive data should be hidden
|
||||
var retrievedStorage Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&retrievedStorage,
|
||||
)
|
||||
|
||||
tc.verifyHiddenData(t, &retrievedStorage)
|
||||
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
|
||||
|
||||
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
|
||||
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
|
||||
var updateResponse Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*updatedStorage,
|
||||
http.StatusOK,
|
||||
&updateResponse,
|
||||
)
|
||||
|
||||
// Verify non-sensitive fields were updated
|
||||
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
|
||||
|
||||
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
|
||||
repository := &StorageRepository{}
|
||||
storageFromDB, err := repository.FindByID(createdStorage.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify original sensitive data is still present in DB
|
||||
tc.verifySensitiveData(t, storageFromDB)
|
||||
|
||||
// Verify non-sensitive fields were updated in DB
|
||||
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
|
||||
|
||||
// Additional verification: Check via GET that data is still hidden
|
||||
var finalRetrieved Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&finalRetrieved,
|
||||
)
|
||||
tc.verifyHiddenData(t, &finalRetrieved)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package storages
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
var storageRepository = &StorageRepository{}
|
||||
@@ -10,6 +11,7 @@ var storageService = &StorageService{
|
||||
storageRepository,
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var storageController = &StorageController{
|
||||
storageService,
|
||||
|
||||
@@ -3,20 +3,28 @@ package storages
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type StorageFileSaver interface {
|
||||
SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error
|
||||
SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error
|
||||
|
||||
GetFile(fileID uuid.UUID) (io.ReadCloser, error)
|
||||
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
|
||||
|
||||
DeleteFile(fileID uuid.UUID) error
|
||||
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
|
||||
|
||||
Validate() error
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
TestConnection() error
|
||||
TestConnection(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
HideSensitiveData()
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -28,8 +29,13 @@ type Storage struct {
|
||||
AzureBlobStorage *azure_blob_storage.AzureBlobStorage `json:"azureBlobStorage" gorm:"foreignKey:StorageID"`
|
||||
}
|
||||
|
||||
func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
|
||||
err := s.getSpecificStorage().SaveFile(logger, fileID, file)
|
||||
func (s *Storage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
err := s.getSpecificStorage().SaveFile(encryptor, logger, fileID, file)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
s.LastSaveError = &lastSaveError
|
||||
@@ -41,15 +47,18 @@ func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
return s.getSpecificStorage().GetFile(fileID)
|
||||
func (s *Storage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
return s.getSpecificStorage().GetFile(encryptor, fileID)
|
||||
}
|
||||
|
||||
func (s *Storage) DeleteFile(fileID uuid.UUID) error {
|
||||
return s.getSpecificStorage().DeleteFile(fileID)
|
||||
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
|
||||
}
|
||||
|
||||
func (s *Storage) Validate() error {
|
||||
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.Type == "" {
|
||||
return errors.New("storage type is required")
|
||||
}
|
||||
@@ -58,17 +67,21 @@ func (s *Storage) Validate() error {
|
||||
return errors.New("storage name is required")
|
||||
}
|
||||
|
||||
return s.getSpecificStorage().Validate()
|
||||
return s.getSpecificStorage().Validate(encryptor)
|
||||
}
|
||||
|
||||
func (s *Storage) TestConnection() error {
|
||||
return s.getSpecificStorage().TestConnection()
|
||||
func (s *Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
return s.getSpecificStorage().TestConnection(encryptor)
|
||||
}
|
||||
|
||||
func (s *Storage) HideSensitiveData() {
|
||||
s.getSpecificStorage().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
|
||||
}
|
||||
|
||||
func (s *Storage) Update(incoming *Storage) {
|
||||
s.Name = incoming.Name
|
||||
s.Type = incoming.Type
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"strconv"
|
||||
"testing"
|
||||
@@ -147,13 +148,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
|
||||
t.Run("Test_TestConnection_ConnectionSucceeds", func(t *testing.T) {
|
||||
err := tc.storage.TestConnection()
|
||||
err := tc.storage.TestConnection(encryptor)
|
||||
assert.NoError(t, err, "TestConnection should succeed")
|
||||
})
|
||||
|
||||
t.Run("Test_TestValidation_ValidationSucceeds", func(t *testing.T) {
|
||||
err := tc.storage.Validate()
|
||||
err := tc.storage.Validate(encryptor)
|
||||
assert.NoError(t, err, "Validate should succeed")
|
||||
})
|
||||
|
||||
@@ -163,10 +166,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
|
||||
fileID := uuid.New()
|
||||
|
||||
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
|
||||
err = tc.storage.SaveFile(
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
bytes.NewReader(fileData),
|
||||
)
|
||||
require.NoError(t, err, "SaveFile should succeed")
|
||||
|
||||
file, err := tc.storage.GetFile(fileID)
|
||||
file, err := tc.storage.GetFile(encryptor, fileID)
|
||||
assert.NoError(t, err, "GetFile should succeed")
|
||||
defer file.Close()
|
||||
|
||||
@@ -180,13 +188,18 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
require.NoError(t, err, "Should be able to read test file")
|
||||
|
||||
fileID := uuid.New()
|
||||
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
|
||||
err = tc.storage.SaveFile(
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
bytes.NewReader(fileData),
|
||||
)
|
||||
require.NoError(t, err, "SaveFile should succeed")
|
||||
|
||||
err = tc.storage.DeleteFile(fileID)
|
||||
err = tc.storage.DeleteFile(encryptor, fileID)
|
||||
assert.NoError(t, err, "DeleteFile should succeed")
|
||||
|
||||
file, err := tc.storage.GetFile(fileID)
|
||||
file, err := tc.storage.GetFile(encryptor, fileID)
|
||||
assert.Error(t, err, "GetFile should fail for non-existent file")
|
||||
if file != nil {
|
||||
file.Close()
|
||||
@@ -196,7 +209,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
|
||||
// Try to delete a non-existent file
|
||||
nonExistentID := uuid.New()
|
||||
err := tc.storage.DeleteFile(nonExistentID)
|
||||
err := tc.storage.DeleteFile(encryptor, nonExistentID)
|
||||
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
|
||||
})
|
||||
})
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -37,8 +38,13 @@ func (s *AzureBlobStorage) TableName() string {
|
||||
return "azure_blob_storages"
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
|
||||
client, err := s.getClient()
|
||||
func (s *AzureBlobStorage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,8 +65,11 @@ func (s *AzureBlobStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
client, err := s.getClient()
|
||||
func (s *AzureBlobStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -80,8 +89,8 @@ func (s *AzureBlobStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
return response.Body, nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
client, err := s.getClient()
|
||||
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -105,7 +114,7 @@ func (s *AzureBlobStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) Validate() error {
|
||||
func (s *AzureBlobStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.ContainerName == "" {
|
||||
return errors.New("container name is required")
|
||||
}
|
||||
@@ -128,16 +137,11 @@ func (s *AzureBlobStorage) Validate() error {
|
||||
return fmt.Errorf("invalid auth method: %s", s.AuthMethod)
|
||||
}
|
||||
|
||||
_, err := s.getClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid Azure Blob configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) TestConnection() error {
|
||||
client, err := s.getClient()
|
||||
func (s *AzureBlobStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -192,6 +196,26 @@ func (s *AzureBlobStorage) HideSensitiveData() {
|
||||
s.AccountKey = ""
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
var err error
|
||||
|
||||
if s.ConnectionString != "" {
|
||||
s.ConnectionString, err = encryptor.Encrypt(s.StorageID, s.ConnectionString)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt Azure connection string: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.AccountKey != "" {
|
||||
s.AccountKey, err = encryptor.Encrypt(s.StorageID, s.AccountKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt Azure account key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) Update(incoming *AzureBlobStorage) {
|
||||
s.AuthMethod = incoming.AuthMethod
|
||||
s.ContainerName = incoming.ContainerName
|
||||
@@ -225,13 +249,18 @@ func (s *AzureBlobStorage) buildBlobName(fileName string) string {
|
||||
return prefix + fileName
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) getClient() (*azblob.Client, error) {
|
||||
func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azblob.Client, error) {
|
||||
var client *azblob.Client
|
||||
var err error
|
||||
|
||||
switch s.AuthMethod {
|
||||
case AuthMethodConnectionString:
|
||||
client, err = azblob.NewClientFromConnectionString(s.ConnectionString, nil)
|
||||
connectionString, decryptErr := encryptor.Decrypt(s.StorageID, s.ConnectionString)
|
||||
if decryptErr != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt Azure connection string: %w", decryptErr)
|
||||
}
|
||||
|
||||
client, err = azblob.NewClientFromConnectionString(connectionString, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to create Azure Blob client from connection string: %w",
|
||||
@@ -239,8 +268,13 @@ func (s *AzureBlobStorage) getClient() (*azblob.Client, error) {
|
||||
)
|
||||
}
|
||||
case AuthMethodAccountKey:
|
||||
accountKey, decryptErr := encryptor.Decrypt(s.StorageID, s.AccountKey)
|
||||
if decryptErr != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt Azure account key: %w", decryptErr)
|
||||
}
|
||||
|
||||
accountURL := s.buildAccountURL()
|
||||
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, s.AccountKey)
|
||||
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, accountKey)
|
||||
if credErr != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure shared key credential: %w", credErr)
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -30,11 +31,12 @@ func (s *GoogleDriveStorage) TableName() string {
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
return s.withRetryOnAuth(func(driveService *drive.Service) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
filename := fileID.String()
|
||||
|
||||
@@ -68,9 +70,12 @@ func (s *GoogleDriveStorage) SaveFile(
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
func (s *GoogleDriveStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
var result io.ReadCloser
|
||||
err := s.withRetryOnAuth(func(driveService *drive.Service) error {
|
||||
err := s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
folderID, err := s.findBackupsFolder(driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
@@ -93,8 +98,11 @@ func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
return s.withRetryOnAuth(func(driveService *drive.Service) error {
|
||||
func (s *GoogleDriveStorage) DeleteFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
folderID, err := s.findBackupsFolder(driveService)
|
||||
if err != nil {
|
||||
@@ -105,7 +113,7 @@ func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) Validate() error {
|
||||
func (s *GoogleDriveStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
switch {
|
||||
case s.ClientID == "":
|
||||
return errors.New("client ID is required")
|
||||
@@ -115,7 +123,12 @@ func (s *GoogleDriveStorage) Validate() error {
|
||||
return errors.New("token JSON is required")
|
||||
}
|
||||
|
||||
// Also validate that the token JSON contains a refresh token
|
||||
// Skip JSON validation if token is already encrypted
|
||||
if strings.HasPrefix(s.TokenJSON, "enc:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Validate that the token JSON contains a refresh token
|
||||
var token oauth2.Token
|
||||
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
|
||||
return fmt.Errorf("invalid token JSON format: %w", err)
|
||||
@@ -128,8 +141,8 @@ func (s *GoogleDriveStorage) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) TestConnection() error {
|
||||
return s.withRetryOnAuth(func(driveService *drive.Service) error {
|
||||
func (s *GoogleDriveStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
testFilename := "test-connection-" + uuid.New().String()
|
||||
testData := []byte("test")
|
||||
@@ -196,6 +209,26 @@ func (s *GoogleDriveStorage) HideSensitiveData() {
|
||||
s.TokenJSON = ""
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
var err error
|
||||
|
||||
if s.ClientSecret != "" {
|
||||
s.ClientSecret, err = encryptor.Encrypt(s.StorageID, s.ClientSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt Google Drive client secret: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.TokenJSON != "" {
|
||||
s.TokenJSON, err = encryptor.Encrypt(s.StorageID, s.TokenJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt Google Drive token JSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
|
||||
s.ClientID = incoming.ClientID
|
||||
|
||||
@@ -209,8 +242,11 @@ func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
|
||||
}
|
||||
|
||||
// withRetryOnAuth executes the provided function with retry logic for authentication errors
|
||||
func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) error {
|
||||
driveService, err := s.getDriveService()
|
||||
func (s *GoogleDriveStorage) withRetryOnAuth(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fn func(*drive.Service) error,
|
||||
) error {
|
||||
driveService, err := s.getDriveService(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -220,7 +256,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
|
||||
// Try to refresh token and retry once
|
||||
fmt.Printf("Google Drive auth error detected, attempting token refresh: %v\n", err)
|
||||
|
||||
if refreshErr := s.refreshToken(); refreshErr != nil {
|
||||
if refreshErr := s.refreshToken(encryptor); refreshErr != nil {
|
||||
// If refresh fails, return a more helpful error message
|
||||
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
|
||||
strings.Contains(refreshErr.Error(), "refresh token") {
|
||||
@@ -237,7 +273,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
|
||||
fmt.Printf("Token refresh successful, retrying operation\n")
|
||||
|
||||
// Get new service with refreshed token
|
||||
driveService, err = s.getDriveService()
|
||||
driveService, err = s.getDriveService(encryptor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create service after token refresh: %w", err)
|
||||
}
|
||||
@@ -268,13 +304,24 @@ func (s *GoogleDriveStorage) isAuthError(err error) bool {
|
||||
}
|
||||
|
||||
// refreshToken refreshes the OAuth2 token and updates the TokenJSON field
|
||||
func (s *GoogleDriveStorage) refreshToken() error {
|
||||
if err := s.Validate(); err != nil {
|
||||
func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) error {
|
||||
if err := s.Validate(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decrypt credentials before use
|
||||
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
|
||||
}
|
||||
|
||||
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return fmt.Errorf("invalid token JSON: %w", err)
|
||||
}
|
||||
|
||||
@@ -289,12 +336,12 @@ func (s *GoogleDriveStorage) refreshToken() error {
|
||||
token.Expiry)
|
||||
|
||||
// Debug: Print the full token JSON structure (sensitive data masked)
|
||||
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(s.TokenJSON))
|
||||
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(tokenJSON))
|
||||
|
||||
ctx := context.Background()
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: google.Endpoint,
|
||||
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
|
||||
}
|
||||
@@ -330,7 +377,7 @@ func (s *GoogleDriveStorage) refreshToken() error {
|
||||
newToken.RefreshToken = token.RefreshToken
|
||||
}
|
||||
|
||||
// Update the stored token JSON
|
||||
// Update the stored token JSON (keep as plaintext in memory, encryption happens on save)
|
||||
newTokenJSON, err := json.Marshal(newToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal refreshed token: %w", err)
|
||||
@@ -368,13 +415,26 @@ func truncateString(s string, maxLen int) string {
|
||||
return s[:maxLen]
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
|
||||
if err := s.Validate(); err != nil {
|
||||
func (s *GoogleDriveStorage) getDriveService(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) (*drive.Service, error) {
|
||||
if err := s.Validate(encryptor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt credentials before use
|
||||
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
|
||||
}
|
||||
|
||||
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
|
||||
}
|
||||
|
||||
var token oauth2.Token
|
||||
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
|
||||
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
|
||||
return nil, fmt.Errorf("invalid token JSON: %w", err)
|
||||
}
|
||||
|
||||
@@ -382,7 +442,7 @@ func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
|
||||
|
||||
cfg := &oauth2.Config{
|
||||
ClientID: s.ClientID,
|
||||
ClientSecret: s.ClientSecret,
|
||||
ClientSecret: clientSecret,
|
||||
Endpoint: google.Endpoint,
|
||||
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,7 +24,12 @@ func (l *LocalStorage) TableName() string {
|
||||
return "local_storages"
|
||||
}
|
||||
|
||||
func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
|
||||
func (l *LocalStorage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
|
||||
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
@@ -107,7 +113,10 @@ func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.R
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
func (l *LocalStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
@@ -122,7 +131,7 @@ func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
return file, nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
|
||||
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
@@ -136,11 +145,11 @@ func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) Validate() error {
|
||||
func (l *LocalStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) TestConnection() error {
|
||||
func (l *LocalStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
testFile := filepath.Join(config.GetEnv().TempFolder, "test_connection")
|
||||
f, err := os.Create(testFile)
|
||||
if err != nil {
|
||||
@@ -160,5 +169,9 @@ func (l *LocalStorage) TestConnection() error {
|
||||
func (l *LocalStorage) HideSensitiveData() {
|
||||
}
|
||||
|
||||
func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalStorage) Update(incoming *LocalStorage) {
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -31,10 +32,15 @@ func (n *NASStorage) TableName() string {
|
||||
return "nas_storages"
|
||||
}
|
||||
|
||||
func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
|
||||
func (n *NASStorage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
|
||||
|
||||
session, err := n.createSession()
|
||||
session, err := n.createSession(encryptor)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to create NAS session: %w", err)
|
||||
@@ -131,8 +137,11 @@ func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Rea
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
session, err := n.createSession()
|
||||
func (n *NASStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
session, err := n.createSession(encryptor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create NAS session: %w", err)
|
||||
}
|
||||
@@ -168,8 +177,8 @@ func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
session, err := n.createSession()
|
||||
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
session, err := n.createSession(encryptor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create NAS session: %w", err)
|
||||
}
|
||||
@@ -202,7 +211,7 @@ func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) Validate() error {
|
||||
func (n *NASStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.Host == "" {
|
||||
return errors.New("NAS host is required")
|
||||
}
|
||||
@@ -219,12 +228,11 @@ func (n *NASStorage) Validate() error {
|
||||
return errors.New("NAS port must be between 1 and 65535")
|
||||
}
|
||||
|
||||
// Test the configuration by creating a session
|
||||
return n.TestConnection()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) TestConnection() error {
|
||||
session, err := n.createSession()
|
||||
func (n *NASStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
session, err := n.createSession(encryptor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to NAS: %w", err)
|
||||
}
|
||||
@@ -255,6 +263,18 @@ func (n *NASStorage) HideSensitiveData() {
|
||||
n.Password = ""
|
||||
}
|
||||
|
||||
func (n *NASStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if n.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(n.StorageID, n.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt NAS password: %w", err)
|
||||
}
|
||||
n.Password = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) Update(incoming *NASStorage) {
|
||||
n.Host = incoming.Host
|
||||
n.Port = incoming.Port
|
||||
@@ -269,18 +289,25 @@ func (n *NASStorage) Update(incoming *NASStorage) {
|
||||
}
|
||||
}
|
||||
|
||||
func (n *NASStorage) createSession() (*smb2.Session, error) {
|
||||
func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.Session, error) {
|
||||
// Create connection with timeout
|
||||
conn, err := n.createConnection()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password before use
|
||||
password, err := encryptor.Decrypt(n.StorageID, n.Password)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("failed to decrypt NAS password: %w", err)
|
||||
}
|
||||
|
||||
// Create SMB2 dialer
|
||||
d := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: n.Username,
|
||||
Password: n.Password,
|
||||
Password: password,
|
||||
Domain: n.Domain,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -31,8 +32,13 @@ func (s *S3Storage) TableName() string {
|
||||
return "s3_storages"
|
||||
}
|
||||
|
||||
func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
|
||||
client, err := s.getClient()
|
||||
func (s *S3Storage) SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -55,8 +61,11 @@ func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Read
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
client, err := s.getClient()
|
||||
func (s *S3Storage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -91,8 +100,8 @@ func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
|
||||
return object, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
|
||||
client, err := s.getClient()
|
||||
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -113,7 +122,7 @@ func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Validate() error {
|
||||
func (s *S3Storage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.S3Bucket == "" {
|
||||
return errors.New("S3 bucket is required")
|
||||
}
|
||||
@@ -124,17 +133,11 @@ func (s *S3Storage) Validate() error {
|
||||
return errors.New("S3 secret key is required")
|
||||
}
|
||||
|
||||
// Try to create a client to validate the configuration
|
||||
_, err := s.getClient()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid S3 configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) TestConnection() error {
|
||||
client, err := s.getClient()
|
||||
func (s *S3Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -195,6 +198,26 @@ func (s *S3Storage) HideSensitiveData() {
|
||||
s.S3SecretKey = ""
|
||||
}
|
||||
|
||||
func (s *S3Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
var err error
|
||||
|
||||
if s.S3AccessKey != "" {
|
||||
s.S3AccessKey, err = encryptor.Encrypt(s.StorageID, s.S3AccessKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt S3 access key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.S3SecretKey != "" {
|
||||
s.S3SecretKey, err = encryptor.Encrypt(s.StorageID, s.S3SecretKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt S3 secret key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) Update(incoming *S3Storage) {
|
||||
s.S3Bucket = incoming.S3Bucket
|
||||
s.S3Region = incoming.S3Region
|
||||
@@ -228,7 +251,7 @@ func (s *S3Storage) buildObjectKey(fileName string) string {
|
||||
return prefix + fileName
|
||||
}
|
||||
|
||||
func (s *S3Storage) getClient() (*minio.Client, error) {
|
||||
func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Client, error) {
|
||||
endpoint := s.S3Endpoint
|
||||
useSSL := true
|
||||
|
||||
@@ -244,6 +267,17 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", s.S3Region)
|
||||
}
|
||||
|
||||
// Decrypt credentials before use
|
||||
accessKey, err := encryptor.Decrypt(s.StorageID, s.S3AccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
|
||||
}
|
||||
|
||||
secretKey, err := encryptor.Decrypt(s.StorageID, s.S3SecretKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
|
||||
}
|
||||
|
||||
// Configure bucket lookup strategy
|
||||
bucketLookup := minio.BucketLookupAuto
|
||||
if s.S3UseVirtualHostedStyle {
|
||||
@@ -252,7 +286,7 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
|
||||
|
||||
// Initialize the MinIO client
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(s.S3AccessKey, s.S3SecretKey, ""),
|
||||
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
|
||||
Secure: useSSL,
|
||||
Region: s.S3Region,
|
||||
BucketLookup: bucketLookup,
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -15,6 +16,7 @@ type StorageService struct {
|
||||
storageRepository *StorageRepository
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *StorageService) SaveStorage(
|
||||
@@ -44,7 +46,11 @@ func (s *StorageService) SaveStorage(
|
||||
|
||||
existingStorage.Update(storage)
|
||||
|
||||
if err := existingStorage.Validate(); err != nil {
|
||||
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -61,7 +67,11 @@ func (s *StorageService) SaveStorage(
|
||||
} else {
|
||||
storage.WorkspaceID = workspaceID
|
||||
|
||||
if err := storage.Validate(); err != nil {
|
||||
if err := storage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := storage.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -174,7 +184,7 @@ func (s *StorageService) TestStorageConnection(
|
||||
return errors.New("insufficient permissions to test storage in this workspace")
|
||||
}
|
||||
|
||||
err = storage.TestConnection()
|
||||
err = storage.TestConnection(s.fieldEncryptor)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
storage.LastSaveError = &lastSaveError
|
||||
@@ -207,7 +217,7 @@ func (s *StorageService) TestStorageConnectionDirect(
|
||||
|
||||
existingStorage.Update(storage)
|
||||
|
||||
if err := existingStorage.Validate(); err != nil {
|
||||
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -216,7 +226,7 @@ func (s *StorageService) TestStorageConnectionDirect(
|
||||
usingStorage = storage
|
||||
}
|
||||
|
||||
return usingStorage.TestConnection()
|
||||
return usingStorage.TestConnection(s.fieldEncryptor)
|
||||
}
|
||||
|
||||
func (s *StorageService) GetStorageByID(
|
||||
|
||||
@@ -309,15 +309,6 @@ func (s *UserService) ChangeUserPasswordByEmail(email string, newPassword string
|
||||
}
|
||||
|
||||
func (s *UserService) ChangeUserPassword(userID uuid.UUID, newPassword string) error {
|
||||
user, err := s.userRepository.GetUserByID(userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user: %w", err)
|
||||
}
|
||||
|
||||
if !user.HasPassword() {
|
||||
return errors.New("user has no password set")
|
||||
}
|
||||
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash new password: %w", err)
|
||||
|
||||
11
backend/internal/util/encryption/di.go
Normal file
11
backend/internal/util/encryption/di.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package encryption
|
||||
|
||||
import users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
|
||||
var fieldEncryptor = &SecretKeyFieldEncryptor{
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
}
|
||||
|
||||
func GetFieldEncryptor() FieldEncryptor {
|
||||
return fieldEncryptor
|
||||
}
|
||||
15
backend/internal/util/encryption/field_encryptor.go
Normal file
15
backend/internal/util/encryption/field_encryptor.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package encryption
|
||||
|
||||
import "github.com/google/uuid"
|
||||
|
||||
type FieldEncryptor interface {
|
||||
// Encrypt encrypts a plaintext string and returns an encrypted string.
|
||||
// If the string is already encrypted, returns it as-is.
|
||||
// Empty strings are returned unchanged.
|
||||
Encrypt(itemID uuid.UUID, plaintext string) (string, error)
|
||||
|
||||
// Decrypt decrypts an encrypted string and returns a plaintext string.
|
||||
// If the string is not encrypted, returns it as-is.
|
||||
// Empty strings are returned unchanged.
|
||||
Decrypt(itemID uuid.UUID, ciphertext string) (string, error)
|
||||
}
|
||||
121
backend/internal/util/encryption/secret_key_field_encryptor.go
Normal file
121
backend/internal/util/encryption/secret_key_field_encryptor.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const encryptedPrefix = "enc:"
|
||||
|
||||
type SecretKeyFieldEncryptor struct {
|
||||
secretKeyRepository *users_repositories.SecretKeyRepository
|
||||
}
|
||||
|
||||
func (e *SecretKeyFieldEncryptor) Encrypt(itemID uuid.UUID, plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if e.isEncrypted(plaintext) {
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
masterKey, err := e.secretKeyRepository.GetSecretKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher([]byte(masterKey)[:32])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
nonce := e.deriveNonce(itemID, masterKey, gcm.NonceSize())
|
||||
|
||||
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
|
||||
|
||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
||||
ciphertextBase64 := base64.StdEncoding.EncodeToString(ciphertext)
|
||||
|
||||
return fmt.Sprintf("%s%s:%s", encryptedPrefix, nonceBase64, ciphertextBase64), nil
|
||||
}
|
||||
|
||||
func (e *SecretKeyFieldEncryptor) Decrypt(itemID uuid.UUID, ciphertext string) (string, error) {
|
||||
if ciphertext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
if !e.isEncrypted(ciphertext) {
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
parts := strings.SplitN(ciphertext, ":", 3)
|
||||
if len(parts) != 3 {
|
||||
return "", errors.New("invalid encrypted format")
|
||||
}
|
||||
|
||||
nonceBase64 := parts[1]
|
||||
ciphertextBase64 := parts[2]
|
||||
|
||||
nonce, err := base64.StdEncoding.DecodeString(nonceBase64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode nonce: %w", err)
|
||||
}
|
||||
|
||||
encryptedData, err := base64.StdEncoding.DecodeString(ciphertextBase64)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := e.secretKeyRepository.GetSecretKey()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher([]byte(masterKey)[:32])
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
plaintext, err := gcm.Open(nil, nonce, encryptedData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
|
||||
func (e *SecretKeyFieldEncryptor) isEncrypted(value string) bool {
|
||||
return strings.HasPrefix(value, encryptedPrefix)
|
||||
}
|
||||
|
||||
func (e *SecretKeyFieldEncryptor) deriveNonce(
|
||||
itemID uuid.UUID,
|
||||
masterKey string,
|
||||
nonceSize int,
|
||||
) []byte {
|
||||
h := hmac.New(sha256.New, []byte(masterKey))
|
||||
h.Write(itemID[:])
|
||||
hash := h.Sum(nil)
|
||||
return hash[:nonceSize]
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Encrypt_Decrypt_RoundTrip(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
plaintext := "my-secret-password"
|
||||
|
||||
encrypted, err := encryptor.Encrypt(itemID, plaintext)
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, encrypted)
|
||||
assert.NotEqual(t, plaintext, encrypted)
|
||||
assert.Contains(t, encrypted, "enc:")
|
||||
|
||||
decrypted, err := encryptor.Decrypt(itemID, encrypted)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func Test_Encrypt_EmptyString_ReturnsEmpty(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
|
||||
encrypted, err := encryptor.Encrypt(itemID, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, encrypted)
|
||||
}
|
||||
|
||||
func Test_Decrypt_EmptyString_ReturnsEmpty(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
|
||||
decrypted, err := encryptor.Decrypt(itemID, "")
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, decrypted)
|
||||
}
|
||||
|
||||
func Test_Decrypt_PlaintextValue_ReturnsAsIs(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
plaintext := "not-encrypted-password"
|
||||
|
||||
decrypted, err := encryptor.Decrypt(itemID, plaintext)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted)
|
||||
}
|
||||
|
||||
func Test_Encrypt_DetectsAlreadyEncryptedFormat(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
alreadyEncrypted := "enc:nonce:ciphertext"
|
||||
|
||||
result, err := encryptor.Encrypt(itemID, alreadyEncrypted)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, alreadyEncrypted, result)
|
||||
}
|
||||
|
||||
func Test_Encrypt_SamePlaintext_DifferentItemIDs_ProducesDifferentCiphertext(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
plaintext := "shared-secret"
|
||||
itemID1 := uuid.New()
|
||||
itemID2 := uuid.New()
|
||||
|
||||
encrypted1, err := encryptor.Encrypt(itemID1, plaintext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
encrypted2, err := encryptor.Encrypt(itemID2, plaintext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, encrypted1, encrypted2)
|
||||
|
||||
decrypted1, err := encryptor.Decrypt(itemID1, encrypted1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted1)
|
||||
|
||||
decrypted2, err := encryptor.Decrypt(itemID2, encrypted2)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, plaintext, decrypted2)
|
||||
}
|
||||
|
||||
func Test_Encrypt_AlreadyEncrypted_ReturnsAsIs(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
plaintext := "my-password"
|
||||
|
||||
encrypted1, err := encryptor.Encrypt(itemID, plaintext)
|
||||
assert.NoError(t, err)
|
||||
|
||||
encrypted2, err := encryptor.Encrypt(itemID, encrypted1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, encrypted1, encrypted2)
|
||||
}
|
||||
|
||||
func Test_Decrypt_MalformedData_ReturnsError(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
|
||||
_, err := encryptor.Decrypt(itemID, "enc:invalid")
|
||||
assert.Error(t, err)
|
||||
|
||||
_, err = encryptor.Decrypt(itemID, "enc:invalid:invalid-base64")
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func Test_EncryptedFormat_ContainsPrefix(t *testing.T) {
|
||||
encryptor := GetFieldEncryptor()
|
||||
itemID := uuid.New()
|
||||
plaintext := "test-secret"
|
||||
|
||||
encrypted, err := encryptor.Encrypt(itemID, plaintext)
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, encrypted, "enc:")
|
||||
}
|
||||
Reference in New Issue
Block a user