Compare commits

...

2 Commits

42 changed files with 1672 additions and 502 deletions

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/config"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/period"
"time"
)
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
continue
}
err = storage.DeleteFile(backup.ID)
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}

View File

@@ -26,6 +26,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -700,7 +701,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -9,6 +9,7 @@ import (
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"time"
)
@@ -25,6 +26,7 @@ var backupService = &BackupService{
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},

View File

@@ -16,6 +16,7 @@ import (
users_models "postgresus-backend/internal/features/users/models"
users_repositories "postgresus-backend/internal/features/users/repositories"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
util_encryption "postgresus-backend/internal/util/encryption"
"slices"
"strings"
"time"
@@ -31,6 +32,7 @@ type BackupService struct {
notificationSender NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyRepo *users_repositories.SecretKeyRepository
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
@@ -284,7 +286,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
@@ -545,7 +547,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
return err
}
err = storage.DeleteFile(backup.ID)
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
@@ -599,7 +601,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(backup.ID)
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}

View File

@@ -13,6 +13,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strings"
"testing"
@@ -56,11 +57,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}
@@ -103,11 +105,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}
@@ -127,11 +130,12 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}

View File

@@ -15,12 +15,13 @@ import (
"time"
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups/encryption"
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"github.com/google/uuid"
@@ -40,8 +41,9 @@ const (
)
type CreatePostgresqlBackupUsecase struct {
logger *slog.Logger
secretKeyRepo *users_repositories.SecretKeyRepository
logger *slog.Logger
secretKeyRepo *users_repositories.SecretKeyRepository
fieldEncryptor encryption.FieldEncryptor
}
// Execute creates a backup of the database
@@ -166,7 +168,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErrCh <- saveErr
}()
@@ -440,7 +442,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *encryption.EncryptionWriter, BackupMetadata, error) {
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
metadata := BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
@@ -449,12 +451,12 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
return storageWriter, nil, metadata, nil
}
salt, err := encryption.GenerateSalt()
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := encryption.GenerateNonce()
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
@@ -464,7 +466,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := encryption.NewEncryptionWriter(
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
@@ -486,7 +488,7 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
encryptionWriter *encryption.EncryptionWriter,
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
saveErrCh chan error,
) {
@@ -510,7 +512,7 @@ func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
}
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
encryptionWriter *encryption.EncryptionWriter,
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
) error {
encryptionCloseErrCh := make(chan error, 1)

View File

@@ -2,12 +2,14 @@ package usecases_postgresql
import (
users_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
logger.GetLogger(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
}
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {

View File

@@ -16,6 +16,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -769,6 +770,71 @@ func createTestDatabaseViaAPI(
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
},
}
var createdDatabase Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusCreated,
&createdDatabase,
)
repository := &DatabaseRepository{}
databaseFromDB, err := repository.FindByID(createdDatabase.ID)
assert.NoError(t, err)
assert.NotNil(t, databaseFromDB)
assert.NotNil(t, databaseFromDB.Postgresql)
assert.True(
t,
strings.HasPrefix(databaseFromDB.Postgresql.Password, "enc:"),
"Password should be encrypted in database with 'enc:' prefix, got: %s",
databaseFromDB.Postgresql.Password,
)
encryptor := encryption.GetFieldEncryptor()
decryptedPassword, err := encryptor.Decrypt(
databaseFromDB.ID,
databaseFromDB.Postgresql.Password,
)
assert.NoError(t, err)
assert.Equal(t, plainPassword, decryptedPassword,
"Decrypted password should match original plaintext password")
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+createdDatabase.ID.String(),
"Bearer "+owner.Token,
http.StatusNoContent,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -815,7 +881,15 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
assert.Equal(t, "original-password-secret", database.Postgresql.Password)
// Verify password is encrypted
assert.True(t, strings.HasPrefix(database.Postgresql.Password, "enc:"),
"Password should be encrypted in database")
// Verify it can be decrypted back to original
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)

View File

@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"regexp"
"slices"
"time"
"github.com/google/uuid"
@@ -59,11 +59,15 @@ func (p *PostgresqlDatabase) Validate() error {
return nil
}
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
func (p *PostgresqlDatabase) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return testSingleDatabaseConnection(logger, ctx, p)
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
}
func (p *PostgresqlDatabase) HideSensitiveData() {
@@ -87,19 +91,42 @@ func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
}
}
func (p *PostgresqlDatabase) EncryptSensitiveFields(
databaseID uuid.UUID,
encryptor encryption.FieldEncryptor,
) error {
if p.Password != "" {
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
if err != nil {
return err
}
p.Password = encrypted
}
return nil
}
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
func testSingleDatabaseConnection(
logger *slog.Logger,
ctx context.Context,
postgresDb *PostgresqlDatabase,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
// For single database backup, we need to connect to the specific database
if postgresDb.Database == nil || *postgresDb.Database == "" {
return errors.New("database name is required for single database backup (pg_dump)")
}
// Decrypt password if needed
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
// Build connection string for the specific database
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
// Test connection
conn, err := pgx.Connect(ctx, connStr)
@@ -182,7 +209,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
}
// buildConnectionStringForDB builds connection string for specific database
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
sslMode := "disable"
if p.IsHttps {
sslMode = "require"
@@ -192,106 +219,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
p.Host,
p.Port,
p.Username,
p.Password,
password,
dbName,
sslMode,
)
}
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
if len(extensions) == 0 {
return nil
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, error) {
if encryptor == nil {
return password, nil
}
if p.Database == nil || *p.Database == "" {
return errors.New("database name is required for installing extensions")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Build connection string for the specific database
connStr := buildConnectionStringForDB(p, *p.Database)
// Connect to database
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
fmt.Println("failed to close connection: %w", closeErr)
}
}()
// Check which extensions are already installed
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
if err != nil {
return fmt.Errorf("failed to check installed extensions: %w", err)
}
// Install missing extensions
for _, extension := range extensions {
if contains(installedExtensions, string(extension)) {
continue // Extension already installed
}
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
}
}
return nil
}
// getInstalledExtensions queries the database for currently installed extensions
func (p *PostgresqlDatabase) getInstalledExtensions(
ctx context.Context,
conn *pgx.Conn,
) ([]string, error) {
query := "SELECT extname FROM pg_extension"
rows, err := conn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
}
defer rows.Close()
var extensions []string
for rows.Next() {
var extname string
if err := rows.Scan(&extname); err != nil {
return nil, fmt.Errorf("failed to scan extension name: %w", err)
}
extensions = append(extensions, extname)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
}
return extensions, nil
}
// installExtension installs a single PostgreSQL extension
func (p *PostgresqlDatabase) installExtension(
ctx context.Context,
conn *pgx.Conn,
extensionName string,
) error {
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
_, err := conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
}
return nil
}
// contains checks if a string slice contains a specific string
func contains(slice []string, item string) bool {
return slices.Contains(slice, item)
return encryptor.Decrypt(databaseID, password)
}

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -19,6 +20,7 @@ var databaseService = &DatabaseService{
[]DatabaseCopyListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var databaseController = &DatabaseController{

View File

@@ -2,6 +2,7 @@ package databases
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -11,7 +12,11 @@ type DatabaseValidator interface {
}
type DatabaseConnector interface {
TestConnection(logger *slog.Logger) error
TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error
HideSensitiveData()
}

View File

@@ -5,6 +5,7 @@ import (
"log/slog"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -56,14 +57,24 @@ func (d *Database) ValidateUpdate(old, new Database) error {
return nil
}
func (d *Database) TestConnection(logger *slog.Logger) error {
return d.getSpecificDatabase().TestConnection(logger)
func (d *Database) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
if d.Postgresql != nil {
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
}
return nil
}
func (d *Database) Update(incoming *Database) {
d.Name = incoming.Name
d.Type = incoming.Type

View File

@@ -11,6 +11,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -26,6 +27,7 @@ type DatabaseService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *DatabaseService) AddDbCreationListener(
@@ -65,6 +67,10 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
database, err = s.dbRepository.Save(database)
if err != nil {
return nil, err
@@ -118,6 +124,10 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
_, err = s.dbRepository.Save(existingDatabase)
if err != nil {
return err
@@ -250,7 +260,7 @@ func (s *DatabaseService) TestDatabaseConnection(
return errors.New("insufficient permissions to test connection for this database")
}
err = database.TestConnection(s.logger)
err = database.TestConnection(s.logger, s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
database.LastBackupErrorMessage = &lastSaveError
@@ -294,7 +304,7 @@ func (s *DatabaseService) TestDatabaseConnectionDirect(
usingDatabase = database
}
return usingDatabase.TestConnection(s.logger)
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) GetDatabaseByID(

View File

@@ -453,70 +453,6 @@ func Test_CrossWorkspaceSecurity_CannotAccessNotifierFromAnotherWorkspace(t *tes
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -553,7 +489,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-bot-token-12345", notifier.TelegramNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "original-bot-token-12345", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.TelegramNotifier.BotToken)
@@ -592,7 +534,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-password-secret", notifier.EmailNotifier.SMTPPassword)
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.EmailNotifier.SMTPPassword)
@@ -625,7 +573,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "xoxb-original-slack-token", notifier.SlackNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "xoxb-original-slack-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.SlackNotifier.BotToken)
@@ -656,11 +610,17 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(
assert.True(
t,
"https://discord.com/api/webhooks/123/original-token",
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/original-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.DiscordNotifier.ChannelWebhookURL)
@@ -691,10 +651,16 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(
t,
"https://outlook.office.com/webhook/original-token",
notifier.TeamsNotifier.WebhookURL,
decrypted,
)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
@@ -813,3 +779,263 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
testCases := []struct {
name string
createNotifier func(workspaceID uuid.UUID) *Notifier
verifySensitiveEncryption func(t *testing.T, notifier *Notifier)
}{
{
name: "Telegram Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram",
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: "plain-telegram-token-123",
TargetChatID: "123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "plain-telegram-token-123", decrypted)
},
},
{
name: "Email Notifier - SMTPPassword encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Email",
NotifierType: NotifierTypeEmail,
EmailNotifier: &email_notifier.EmailNotifier{
TargetEmail: "test@example.com",
SMTPHost: "smtp.example.com",
SMTPPort: 587,
SMTPUser: "user@example.com",
SMTPPassword: "plain-smtp-password-456",
From: "noreply@example.com",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "plain-smtp-password-456", decrypted)
},
},
{
name: "Slack Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Slack",
NotifierType: NotifierTypeSlack,
SlackNotifier: &slack_notifier.SlackNotifier{
BotToken: "plain-slack-token-789",
TargetChatID: "C0123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "plain-slack-token-789", decrypted)
},
},
{
name: "Discord Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Discord",
NotifierType: NotifierTypeDiscord,
DiscordNotifier: &discord_notifier.DiscordNotifier{
ChannelWebhookURL: "https://discord.com/api/webhooks/123/abc",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/abc", decrypted)
},
},
{
name: "Teams Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Teams",
NotifierType: NotifierTypeTeams,
TeamsNotifier: &teams_notifier.TeamsNotifier{
WebhookURL: "https://outlook.office.com/webhook/test123",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(t, "https://outlook.office.com/webhook/test123", decrypted)
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Webhook",
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Create notifier via API (plaintext credentials)
var createdNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
tc.createNotifier(workspace.ID),
http.StatusOK,
&createdNotifier,
)
// Read from DB directly (bypass service layer)
repository := &NotifierRepository{}
notifierFromDB, err := repository.FindByID(createdNotifier.ID)
assert.NoError(t, err)
// Verify encryption
tc.verifySensitiveEncryption(t, notifierFromDB)
// Cleanup
deleteNotifier(t, router, createdNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func isEncrypted(value string) bool {
return len(value) > 4 && value[:4] == "enc:"
}
func decryptField(t *testing.T, notifierID uuid.UUID, encryptedValue string) string {
encryptor := GetNotifierService().fieldEncryptor
decrypted, err := encryptor.Decrypt(notifierID, encryptedValue)
assert.NoError(t, err)
return decrypted
}

View File

@@ -3,6 +3,7 @@ package notifiers
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -12,6 +13,7 @@ var notifierService = &NotifierService{
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var notifierController = &NotifierController{
notifierService,

View File

@@ -1,11 +1,21 @@
package notifiers
import "log/slog"
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
)
type NotificationSender interface {
Send(logger *slog.Logger, heading string, message string) error
Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -9,6 +9,7 @@ import (
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
return "notifiers"
}
func (n *Notifier) Validate() error {
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.Name == "" {
return errors.New("name is required")
}
return n.getSpecificNotifier().Validate()
return n.getSpecificNotifier().Validate(encryptor)
}
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
err := n.getSpecificNotifier().Send(logger, heading, message)
func (n *Notifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
if err != nil {
lastSendError := err.Error()
@@ -58,6 +64,10 @@ func (n *Notifier) HideSensitiveData() {
n.getSpecificNotifier().HideSensitiveData()
}
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
}
func (n *Notifier) Update(incoming *Notifier) {
n.Name = incoming.Name
n.NotifierType = incoming.NotifierType

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
return "discord_notifiers"
}
func (d *DiscordNotifier) Validate() error {
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
return nil
}
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (d *DiscordNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
@@ -81,3 +92,14 @@ func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
d.ChannelWebhookURL = incoming.ChannelWebhookURL
}
}
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL != "" {
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
d.ChannelWebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net"
"net/smtp"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -34,7 +35,7 @@ func (e *EmailNotifier) TableName() string {
return "email_notifiers"
}
func (e *EmailNotifier) Validate() error {
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if e.TargetEmail == "" {
return errors.New("target email is required")
}
@@ -55,7 +56,22 @@ func (e *EmailNotifier) Validate() error {
return nil
}
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (e *EmailNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
// Decrypt SMTP password if provided
var smtpPassword string
if e.SMTPPassword != "" {
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
}
smtpPassword = decrypted
}
// Compose email
from := e.From
if from == "" {
@@ -85,7 +101,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
timeout := DefaultTimeout
// Determine if authentication is required
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
// Handle different port scenarios
if e.SMTPPort == ImplicitTLSPort {
@@ -116,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Set up authentication only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -179,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Authenticate only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -229,3 +245,14 @@ func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.SMTPPassword = incoming.SMTPPassword
}
}
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if e.SMTPPassword != "" {
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
}
e.SMTPPassword = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
"time"
@@ -23,7 +24,7 @@ type SlackNotifier struct {
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
func (s *SlackNotifier) Validate() error {
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if s.BotToken == "" {
return errors.New("bot token is required")
}
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
return nil
}
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
func (s *SlackNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
full := fmt.Sprintf("*%s*", heading)
if message != "" {
@@ -80,7 +90,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Authorization", "Bearer "+s.BotToken)
req.Header.Set("Authorization", "Bearer "+botToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@@ -144,3 +154,14 @@ func (s *SlackNotifier) Update(incoming *SlackNotifier) {
s.BotToken = incoming.BotToken
}
}
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if s.BotToken != "" {
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
s.BotToken = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
return "teams_notifiers"
}
func (n *TeamsNotifier) Validate() error {
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL == "" {
return errors.New("webhook_url is required")
}
u, err := url.Parse(n.WebhookURL)
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
u, err := url.Parse(webhookURL)
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
return errors.New("invalid webhook_url")
}
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
}
type cardAttachment struct {
ContentType string `json:"contentType"`
Content interface{} `json:"content"`
ContentType string `json:"contentType"`
Content any `json:"content"`
}
type payload struct {
@@ -43,11 +50,20 @@ type payload struct {
Attachments []cardAttachment `json:"attachments,omitempty"`
}
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
if err := n.Validate(); err != nil {
func (n *TeamsNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
if err := n.Validate(encryptor); err != nil {
return err
}
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
card := map[string]any{
"type": "AdaptiveCard",
"version": "1.4",
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
}
body, _ := json.Marshal(p)
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return err
}
@@ -104,3 +120,14 @@ func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
n.WebhookURL = incoming.WebhookURL
}
}
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
n.WebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
return "telegram_notifiers"
}
func (t *TelegramNotifier) Validate() error {
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.BotToken == "" {
return errors.New("bot token is required")
}
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
return nil
}
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *TelegramNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
}
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
data := url.Values{}
data.Set("chat_id", t.TargetChatID)
@@ -93,3 +104,14 @@ func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
t.BotToken = incoming.BotToken
}
}
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.BotToken != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
t.BotToken = encrypted
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
return "webhook_notifiers"
}
func (t *WebhookNotifier) Validate() error {
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
return nil
}
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *WebhookNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
switch t.WebhookMethod {
case WebhookMethodGET:
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
t.WebhookURL,
webhookURL,
url.QueryEscape(heading),
url.QueryEscape(message),
)
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal webhook payload: %w", err)
}
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to send POST webhook: %w", err)
}
@@ -110,3 +121,14 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
t.WebhookURL = incoming.WebhookURL
t.WebhookMethod = incoming.WebhookMethod
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
t.WebhookURL = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -17,6 +18,7 @@ type NotifierService struct {
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *NotifierService) SaveNotifier(
@@ -46,7 +48,11 @@ func (s *NotifierService) SaveNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -63,7 +69,11 @@ func (s *NotifierService) SaveNotifier(
} else {
notifier.WorkspaceID = workspaceID
if err := notifier.Validate(); err != nil {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := notifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -175,7 +185,7 @@ func (s *NotifierService) SendTestNotification(
return errors.New("insufficient permissions to test notifier in this workspace")
}
err = notifier.Send(s.logger, "Test message", "This is a test message")
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
if err != nil {
return err
}
@@ -205,16 +215,24 @@ func (s *NotifierService) SendTestNotificationToNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = existingNotifier
} else {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = notifier
}
return usingNotifier.Send(s.logger, "Test message", "This is a test message")
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
}
func (s *NotifierService) SendNotification(
@@ -233,7 +251,7 @@ func (s *NotifierService) SendNotification(
return
}
err = notifiedFromDb.Send(s.logger, title, message)
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
if err != nil {
errMsg := err.Error()
notifiedFromDb.LastSendError = &errMsg

View File

@@ -29,6 +29,7 @@ import (
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
util_encryption "postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -309,6 +310,7 @@ func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -338,7 +340,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -23,6 +23,7 @@ import (
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
util_encryption "postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/tools"
@@ -209,7 +210,8 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
rawReader, err := storage.GetFile(backup.ID)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)

View File

@@ -3,10 +3,14 @@ package storages
import (
"fmt"
"net/http"
"strings"
"testing"
audit_logs "postgresus-backend/internal/features/audit_logs"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
@@ -14,6 +18,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
@@ -438,6 +443,386 @@ func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *test
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
"S3AccessKey should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
"S3SecretKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
assert.NoError(t, err)
assert.Equal(t, "original-access-key", accessKey)
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
assert.NoError(t, err)
assert.Equal(t, "original-secret-key", secretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
{
name: "NAS Storage",
storageType: StorageTypeNAS,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Test NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas.example.com",
Port: 445,
Share: "backups",
Username: "testuser",
Password: "original-password",
UseSSL: false,
Domain: "WORKGROUP",
Path: "/test",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Updated NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas2.example.com",
Port: 445,
Share: "backups2",
Username: "testuser2",
Password: "",
UseSSL: true,
Domain: "WORKGROUP2",
Path: "/test2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.NASStorage.Password)
},
},
{
name: "Azure Blob Storage (Connection String)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "original-connection-string",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
"ConnectionString should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
connectionString, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.ConnectionString,
)
assert.NoError(t, err)
assert.Equal(t, "original-connection-string", connectionString)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Azure Blob Storage (Account Key)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "testaccount",
AccountKey: "original-account-key",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "updatedaccount",
AccountKey: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
"AccountKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accountKey, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.AccountKey,
)
assert.NoError(t, err)
assert.Equal(t, "original-account-key", accountKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Google Drive Storage",
storageType: StorageTypeGoogleDrive,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Test Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Updated Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "updated-client-id",
ClientSecret: "",
TokenJSON: "",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
"ClientSecret should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
"TokenJSON should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
clientSecret, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.ClientSecret,
)
assert.NoError(t, err)
assert.Equal(t, "original-client-secret", clientSecret)
tokenJSON, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.TokenJSON,
)
assert.NoError(t, err)
assert.Equal(
t,
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
tokenJSON,
)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Verify sensitive data is encrypted in repository after creation
repository := &StorageRepository{}
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
tc.verifySensitiveData(t, storageFromDBAfterCreate)
// Phase 3: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -485,158 +870,3 @@ func deleteStorage(
http.StatusOK,
)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "original-access-key", storage.S3Storage.S3AccessKey)
assert.Equal(t, "original-secret-key", storage.S3Storage.S3SecretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &StorageRepository{}
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}

View File

@@ -3,6 +3,7 @@ package storages
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
)
var storageRepository = &StorageRepository{}
@@ -10,6 +11,7 @@ var storageService = &StorageService{
storageRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var storageController = &StorageController{
storageService,

View File

@@ -3,20 +3,28 @@ package storages
import (
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
type StorageFileSaver interface {
SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error
SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error
GetFile(fileID uuid.UUID) (io.ReadCloser, error)
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
DeleteFile(fileID uuid.UUID) error
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
TestConnection() error
TestConnection(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -9,6 +9,7 @@ import (
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -28,8 +29,13 @@ type Storage struct {
AzureBlobStorage *azure_blob_storage.AzureBlobStorage `json:"azureBlobStorage" gorm:"foreignKey:StorageID"`
}
func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
err := s.getSpecificStorage().SaveFile(logger, fileID, file)
func (s *Storage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
err := s.getSpecificStorage().SaveFile(encryptor, logger, fileID, file)
if err != nil {
lastSaveError := err.Error()
s.LastSaveError = &lastSaveError
@@ -41,15 +47,18 @@ func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader
return nil
}
func (s *Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(fileID)
func (s *Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(encryptor, fileID)
}
func (s *Storage) DeleteFile(fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(fileID)
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
}
func (s *Storage) Validate() error {
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.Type == "" {
return errors.New("storage type is required")
}
@@ -58,17 +67,21 @@ func (s *Storage) Validate() error {
return errors.New("storage name is required")
}
return s.getSpecificStorage().Validate()
return s.getSpecificStorage().Validate(encryptor)
}
func (s *Storage) TestConnection() error {
return s.getSpecificStorage().TestConnection()
func (s *Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().TestConnection(encryptor)
}
func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
}
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type

View File

@@ -13,6 +13,7 @@ import (
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strconv"
"testing"
@@ -147,13 +148,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
t.Run("Test_TestConnection_ConnectionSucceeds", func(t *testing.T) {
err := tc.storage.TestConnection()
err := tc.storage.TestConnection(encryptor)
assert.NoError(t, err, "TestConnection should succeed")
})
t.Run("Test_TestValidation_ValidationSucceeds", func(t *testing.T) {
err := tc.storage.Validate()
err := tc.storage.Validate(encryptor)
assert.NoError(t, err, "Validate should succeed")
})
@@ -163,10 +166,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.NoError(t, err, "GetFile should succeed")
defer file.Close()
@@ -180,13 +188,18 @@ func Test_Storage_BasicOperations(t *testing.T) {
require.NoError(t, err, "Should be able to read test file")
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
err = tc.storage.DeleteFile(fileID)
err = tc.storage.DeleteFile(encryptor, fileID)
assert.NoError(t, err, "DeleteFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.Error(t, err, "GetFile should fail for non-existent file")
if file != nil {
file.Close()
@@ -196,7 +209,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
// Try to delete a non-existent file
nonExistentID := uuid.New()
err := tc.storage.DeleteFile(nonExistentID)
err := tc.storage.DeleteFile(encryptor, nonExistentID)
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
})
})

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -37,8 +38,13 @@ func (s *AzureBlobStorage) TableName() string {
return "azure_blob_storages"
}
func (s *AzureBlobStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
client, err := s.getClient()
func (s *AzureBlobStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -59,8 +65,11 @@ func (s *AzureBlobStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file
return nil
}
func (s *AzureBlobStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
client, err := s.getClient()
func (s *AzureBlobStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
@@ -80,8 +89,8 @@ func (s *AzureBlobStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return response.Body, nil
}
func (s *AzureBlobStorage) DeleteFile(fileID uuid.UUID) error {
client, err := s.getClient()
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -105,7 +114,7 @@ func (s *AzureBlobStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (s *AzureBlobStorage) Validate() error {
func (s *AzureBlobStorage) Validate(encryptor encryption.FieldEncryptor) error {
if s.ContainerName == "" {
return errors.New("container name is required")
}
@@ -128,16 +137,11 @@ func (s *AzureBlobStorage) Validate() error {
return fmt.Errorf("invalid auth method: %s", s.AuthMethod)
}
_, err := s.getClient()
if err != nil {
return fmt.Errorf("invalid Azure Blob configuration: %w", err)
}
return nil
}
func (s *AzureBlobStorage) TestConnection() error {
client, err := s.getClient()
func (s *AzureBlobStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -192,6 +196,26 @@ func (s *AzureBlobStorage) HideSensitiveData() {
s.AccountKey = ""
}
func (s *AzureBlobStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ConnectionString != "" {
s.ConnectionString, err = encryptor.Encrypt(s.StorageID, s.ConnectionString)
if err != nil {
return fmt.Errorf("failed to encrypt Azure connection string: %w", err)
}
}
if s.AccountKey != "" {
s.AccountKey, err = encryptor.Encrypt(s.StorageID, s.AccountKey)
if err != nil {
return fmt.Errorf("failed to encrypt Azure account key: %w", err)
}
}
return nil
}
func (s *AzureBlobStorage) Update(incoming *AzureBlobStorage) {
s.AuthMethod = incoming.AuthMethod
s.ContainerName = incoming.ContainerName
@@ -225,13 +249,18 @@ func (s *AzureBlobStorage) buildBlobName(fileName string) string {
return prefix + fileName
}
func (s *AzureBlobStorage) getClient() (*azblob.Client, error) {
func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azblob.Client, error) {
var client *azblob.Client
var err error
switch s.AuthMethod {
case AuthMethodConnectionString:
client, err = azblob.NewClientFromConnectionString(s.ConnectionString, nil)
connectionString, decryptErr := encryptor.Decrypt(s.StorageID, s.ConnectionString)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure connection string: %w", decryptErr)
}
client, err = azblob.NewClientFromConnectionString(connectionString, nil)
if err != nil {
return nil, fmt.Errorf(
"failed to create Azure Blob client from connection string: %w",
@@ -239,8 +268,13 @@ func (s *AzureBlobStorage) getClient() (*azblob.Client, error) {
)
}
case AuthMethodAccountKey:
accountKey, decryptErr := encryptor.Decrypt(s.StorageID, s.AccountKey)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure account key: %w", decryptErr)
}
accountURL := s.buildAccountURL()
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, s.AccountKey)
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, accountKey)
if credErr != nil {
return nil, fmt.Errorf("failed to create Azure shared key credential: %w", credErr)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -30,11 +31,12 @@ func (s *GoogleDriveStorage) TableName() string {
}
func (s *GoogleDriveStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
filename := fileID.String()
@@ -68,9 +70,12 @@ func (s *GoogleDriveStorage) SaveFile(
})
}
func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
func (s *GoogleDriveStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
var result io.ReadCloser
err := s.withRetryOnAuth(func(driveService *drive.Service) error {
err := s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
return fmt.Errorf("failed to find backups folder: %w", err)
@@ -93,8 +98,11 @@ func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return result, err
}
func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
@@ -105,7 +113,7 @@ func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
})
}
func (s *GoogleDriveStorage) Validate() error {
func (s *GoogleDriveStorage) Validate(encryptor encryption.FieldEncryptor) error {
switch {
case s.ClientID == "":
return errors.New("client ID is required")
@@ -115,7 +123,12 @@ func (s *GoogleDriveStorage) Validate() error {
return errors.New("token JSON is required")
}
// Also validate that the token JSON contains a refresh token
// Skip JSON validation if token is already encrypted
if strings.HasPrefix(s.TokenJSON, "enc:") {
return nil
}
// Validate that the token JSON contains a refresh token
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON format: %w", err)
@@ -128,8 +141,8 @@ func (s *GoogleDriveStorage) Validate() error {
return nil
}
func (s *GoogleDriveStorage) TestConnection() error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
func (s *GoogleDriveStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
testFilename := "test-connection-" + uuid.New().String()
testData := []byte("test")
@@ -196,6 +209,26 @@ func (s *GoogleDriveStorage) HideSensitiveData() {
s.TokenJSON = ""
}
func (s *GoogleDriveStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ClientSecret != "" {
s.ClientSecret, err = encryptor.Encrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive client secret: %w", err)
}
}
if s.TokenJSON != "" {
s.TokenJSON, err = encryptor.Encrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive token JSON: %w", err)
}
}
return nil
}
func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
s.ClientID = incoming.ClientID
@@ -209,8 +242,11 @@ func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
}
// withRetryOnAuth executes the provided function with retry logic for authentication errors
func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) error {
driveService, err := s.getDriveService()
func (s *GoogleDriveStorage) withRetryOnAuth(
encryptor encryption.FieldEncryptor,
fn func(*drive.Service) error,
) error {
driveService, err := s.getDriveService(encryptor)
if err != nil {
return err
}
@@ -220,7 +256,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
// Try to refresh token and retry once
fmt.Printf("Google Drive auth error detected, attempting token refresh: %v\n", err)
if refreshErr := s.refreshToken(); refreshErr != nil {
if refreshErr := s.refreshToken(encryptor); refreshErr != nil {
// If refresh fails, return a more helpful error message
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
strings.Contains(refreshErr.Error(), "refresh token") {
@@ -237,7 +273,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
fmt.Printf("Token refresh successful, retrying operation\n")
// Get new service with refreshed token
driveService, err = s.getDriveService()
driveService, err = s.getDriveService(encryptor)
if err != nil {
return fmt.Errorf("failed to create service after token refresh: %w", err)
}
@@ -268,13 +304,24 @@ func (s *GoogleDriveStorage) isAuthError(err error) bool {
}
// refreshToken refreshes the OAuth2 token and updates the TokenJSON field
func (s *GoogleDriveStorage) refreshToken() error {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) error {
if err := s.Validate(encryptor); err != nil {
return err
}
// Decrypt credentials before use
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON: %w", err)
}
@@ -289,12 +336,12 @@ func (s *GoogleDriveStorage) refreshToken() error {
token.Expiry)
// Debug: Print the full token JSON structure (sensitive data masked)
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(s.TokenJSON))
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(tokenJSON))
ctx := context.Background()
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}
@@ -330,7 +377,7 @@ func (s *GoogleDriveStorage) refreshToken() error {
newToken.RefreshToken = token.RefreshToken
}
// Update the stored token JSON
// Update the stored token JSON (keep as plaintext in memory, encryption happens on save)
newTokenJSON, err := json.Marshal(newToken)
if err != nil {
return fmt.Errorf("failed to marshal refreshed token: %w", err)
@@ -368,13 +415,26 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen]
}
func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) getDriveService(
encryptor encryption.FieldEncryptor,
) (*drive.Service, error) {
if err := s.Validate(encryptor); err != nil {
return nil, err
}
// Decrypt credentials before use
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("invalid token JSON: %w", err)
}
@@ -382,7 +442,7 @@ func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"postgresus-backend/internal/config"
"postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"github.com/google/uuid"
@@ -23,7 +24,12 @@ func (l *LocalStorage) TableName() string {
return "local_storages"
}
func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (l *LocalStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
err := files_utils.EnsureDirectories([]string{
@@ -107,7 +113,10 @@ func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.R
return nil
}
func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
func (l *LocalStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -122,7 +131,7 @@ func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return file, nil
}
func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -136,11 +145,11 @@ func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (l *LocalStorage) Validate() error {
func (l *LocalStorage) Validate(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
testFile := filepath.Join(config.GetEnv().TempFolder, "test_connection")
f, err := os.Create(testFile)
if err != nil {
@@ -160,5 +169,9 @@ func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) HideSensitiveData() {
}
func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) Update(incoming *LocalStorage) {
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net"
"path/filepath"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -31,10 +32,15 @@ func (n *NASStorage) TableName() string {
return "nas_storages"
}
func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (n *NASStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
session, err := n.createSession()
session, err := n.createSession(encryptor)
if err != nil {
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to create NAS session: %w", err)
@@ -131,8 +137,11 @@ func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Rea
return nil
}
func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
session, err := n.createSession()
func (n *NASStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
session, err := n.createSession(encryptor)
if err != nil {
return nil, fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -168,8 +177,8 @@ func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
}, nil
}
func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
session, err := n.createSession()
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -202,7 +211,7 @@ func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (n *NASStorage) Validate() error {
func (n *NASStorage) Validate(encryptor encryption.FieldEncryptor) error {
if n.Host == "" {
return errors.New("NAS host is required")
}
@@ -219,12 +228,11 @@ func (n *NASStorage) Validate() error {
return errors.New("NAS port must be between 1 and 65535")
}
// Test the configuration by creating a session
return n.TestConnection()
return nil
}
func (n *NASStorage) TestConnection() error {
session, err := n.createSession()
func (n *NASStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to connect to NAS: %w", err)
}
@@ -255,6 +263,18 @@ func (n *NASStorage) HideSensitiveData() {
n.Password = ""
}
func (n *NASStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.Password != "" {
encrypted, err := encryptor.Encrypt(n.StorageID, n.Password)
if err != nil {
return fmt.Errorf("failed to encrypt NAS password: %w", err)
}
n.Password = encrypted
}
return nil
}
func (n *NASStorage) Update(incoming *NASStorage) {
n.Host = incoming.Host
n.Port = incoming.Port
@@ -269,18 +289,25 @@ func (n *NASStorage) Update(incoming *NASStorage) {
}
}
func (n *NASStorage) createSession() (*smb2.Session, error) {
func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.Session, error) {
// Create connection with timeout
conn, err := n.createConnection()
if err != nil {
return nil, err
}
// Decrypt password before use
password, err := encryptor.Decrypt(n.StorageID, n.Password)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to decrypt NAS password: %w", err)
}
// Create SMB2 dialer
d := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: n.Username,
Password: n.Password,
Password: password,
Domain: n.Domain,
},
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -31,8 +32,13 @@ func (s *S3Storage) TableName() string {
return "s3_storages"
}
func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
client, err := s.getClient()
func (s *S3Storage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -55,8 +61,11 @@ func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Read
return nil
}
func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
client, err := s.getClient()
func (s *S3Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
@@ -91,8 +100,8 @@ func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return object, nil
}
func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
client, err := s.getClient()
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -113,7 +122,7 @@ func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (s *S3Storage) Validate() error {
func (s *S3Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.S3Bucket == "" {
return errors.New("S3 bucket is required")
}
@@ -124,17 +133,11 @@ func (s *S3Storage) Validate() error {
return errors.New("S3 secret key is required")
}
// Try to create a client to validate the configuration
_, err := s.getClient()
if err != nil {
return fmt.Errorf("invalid S3 configuration: %w", err)
}
return nil
}
func (s *S3Storage) TestConnection() error {
client, err := s.getClient()
func (s *S3Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -195,6 +198,26 @@ func (s *S3Storage) HideSensitiveData() {
s.S3SecretKey = ""
}
func (s *S3Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.S3AccessKey != "" {
s.S3AccessKey, err = encryptor.Encrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 access key: %w", err)
}
}
if s.S3SecretKey != "" {
s.S3SecretKey, err = encryptor.Encrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 secret key: %w", err)
}
}
return nil
}
func (s *S3Storage) Update(incoming *S3Storage) {
s.S3Bucket = incoming.S3Bucket
s.S3Region = incoming.S3Region
@@ -228,7 +251,7 @@ func (s *S3Storage) buildObjectKey(fileName string) string {
return prefix + fileName
}
func (s *S3Storage) getClient() (*minio.Client, error) {
func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Client, error) {
endpoint := s.S3Endpoint
useSSL := true
@@ -244,6 +267,17 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", s.S3Region)
}
// Decrypt credentials before use
accessKey, err := encryptor.Decrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
}
secretKey, err := encryptor.Decrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
}
// Configure bucket lookup strategy
bucketLookup := minio.BucketLookupAuto
if s.S3UseVirtualHostedStyle {
@@ -252,7 +286,7 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
// Initialize the MinIO client
minioClient, err := minio.New(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(s.S3AccessKey, s.S3SecretKey, ""),
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
Secure: useSSL,
Region: s.S3Region,
BucketLookup: bucketLookup,

View File

@@ -7,6 +7,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -15,6 +16,7 @@ type StorageService struct {
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *StorageService) SaveStorage(
@@ -44,7 +46,11 @@ func (s *StorageService) SaveStorage(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -61,7 +67,11 @@ func (s *StorageService) SaveStorage(
} else {
storage.WorkspaceID = workspaceID
if err := storage.Validate(); err != nil {
if err := storage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := storage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -174,7 +184,7 @@ func (s *StorageService) TestStorageConnection(
return errors.New("insufficient permissions to test storage in this workspace")
}
err = storage.TestConnection()
err = storage.TestConnection(s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
storage.LastSaveError = &lastSaveError
@@ -207,7 +217,7 @@ func (s *StorageService) TestStorageConnectionDirect(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -216,7 +226,7 @@ func (s *StorageService) TestStorageConnectionDirect(
usingStorage = storage
}
return usingStorage.TestConnection()
return usingStorage.TestConnection(s.fieldEncryptor)
}
func (s *StorageService) GetStorageByID(

View File

@@ -309,15 +309,6 @@ func (s *UserService) ChangeUserPasswordByEmail(email string, newPassword string
}
func (s *UserService) ChangeUserPassword(userID uuid.UUID, newPassword string) error {
user, err := s.userRepository.GetUserByID(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if !user.HasPassword() {
return errors.New("user has no password set")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)

View File

@@ -0,0 +1,11 @@
package encryption
import users_repositories "postgresus-backend/internal/features/users/repositories"
var fieldEncryptor = &SecretKeyFieldEncryptor{
users_repositories.GetSecretKeyRepository(),
}
func GetFieldEncryptor() FieldEncryptor {
return fieldEncryptor
}

View File

@@ -0,0 +1,15 @@
package encryption
import "github.com/google/uuid"
type FieldEncryptor interface {
// Encrypt encrypts a plaintext string and returns an encrypted string.
// If the string is already encrypted, returns it as-is.
// Empty strings are returned unchanged.
Encrypt(itemID uuid.UUID, plaintext string) (string, error)
// Decrypt decrypts an encrypted string and returns a plaintext string.
// If the string is not encrypted, returns it as-is.
// Empty strings are returned unchanged.
Decrypt(itemID uuid.UUID, ciphertext string) (string, error)
}

View File

@@ -0,0 +1,121 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strings"
users_repositories "postgresus-backend/internal/features/users/repositories"
"github.com/google/uuid"
)
const encryptedPrefix = "enc:"
type SecretKeyFieldEncryptor struct {
secretKeyRepository *users_repositories.SecretKeyRepository
}
func (e *SecretKeyFieldEncryptor) Encrypt(itemID uuid.UUID, plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
if e.isEncrypted(plaintext) {
return plaintext, nil
}
masterKey, err := e.secretKeyRepository.GetSecretKey()
if err != nil {
return "", fmt.Errorf("failed to get master key: %w", err)
}
block, err := aes.NewCipher([]byte(masterKey)[:32])
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM: %w", err)
}
nonce := e.deriveNonce(itemID, masterKey, gcm.NonceSize())
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
ciphertextBase64 := base64.StdEncoding.EncodeToString(ciphertext)
return fmt.Sprintf("%s%s:%s", encryptedPrefix, nonceBase64, ciphertextBase64), nil
}
func (e *SecretKeyFieldEncryptor) Decrypt(itemID uuid.UUID, ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil
}
if !e.isEncrypted(ciphertext) {
return ciphertext, nil
}
parts := strings.SplitN(ciphertext, ":", 3)
if len(parts) != 3 {
return "", errors.New("invalid encrypted format")
}
nonceBase64 := parts[1]
ciphertextBase64 := parts[2]
nonce, err := base64.StdEncoding.DecodeString(nonceBase64)
if err != nil {
return "", fmt.Errorf("failed to decode nonce: %w", err)
}
encryptedData, err := base64.StdEncoding.DecodeString(ciphertextBase64)
if err != nil {
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
masterKey, err := e.secretKeyRepository.GetSecretKey()
if err != nil {
return "", fmt.Errorf("failed to get master key: %w", err)
}
block, err := aes.NewCipher([]byte(masterKey)[:32])
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, encryptedData, nil)
if err != nil {
return "", fmt.Errorf("failed to decrypt: %w", err)
}
return string(plaintext), nil
}
func (e *SecretKeyFieldEncryptor) isEncrypted(value string) bool {
return strings.HasPrefix(value, encryptedPrefix)
}
func (e *SecretKeyFieldEncryptor) deriveNonce(
itemID uuid.UUID,
masterKey string,
nonceSize int,
) []byte {
h := hmac.New(sha256.New, []byte(masterKey))
h.Write(itemID[:])
hash := h.Sum(nil)
return hash[:nonceSize]
}

View File

@@ -0,0 +1,120 @@
package encryption
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_Encrypt_Decrypt_RoundTrip(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "my-secret-password"
encrypted, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
assert.NotEmpty(t, encrypted)
assert.NotEqual(t, plaintext, encrypted)
assert.Contains(t, encrypted, "enc:")
decrypted, err := encryptor.Decrypt(itemID, encrypted)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}
func Test_Encrypt_EmptyString_ReturnsEmpty(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
encrypted, err := encryptor.Encrypt(itemID, "")
assert.NoError(t, err)
assert.Empty(t, encrypted)
}
func Test_Decrypt_EmptyString_ReturnsEmpty(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
decrypted, err := encryptor.Decrypt(itemID, "")
assert.NoError(t, err)
assert.Empty(t, decrypted)
}
func Test_Decrypt_PlaintextValue_ReturnsAsIs(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "not-encrypted-password"
decrypted, err := encryptor.Decrypt(itemID, plaintext)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}
func Test_Encrypt_DetectsAlreadyEncryptedFormat(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
alreadyEncrypted := "enc:nonce:ciphertext"
result, err := encryptor.Encrypt(itemID, alreadyEncrypted)
assert.NoError(t, err)
assert.Equal(t, alreadyEncrypted, result)
}
func Test_Encrypt_SamePlaintext_DifferentItemIDs_ProducesDifferentCiphertext(t *testing.T) {
encryptor := GetFieldEncryptor()
plaintext := "shared-secret"
itemID1 := uuid.New()
itemID2 := uuid.New()
encrypted1, err := encryptor.Encrypt(itemID1, plaintext)
assert.NoError(t, err)
encrypted2, err := encryptor.Encrypt(itemID2, plaintext)
assert.NoError(t, err)
assert.NotEqual(t, encrypted1, encrypted2)
decrypted1, err := encryptor.Decrypt(itemID1, encrypted1)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted1)
decrypted2, err := encryptor.Decrypt(itemID2, encrypted2)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted2)
}
func Test_Encrypt_AlreadyEncrypted_ReturnsAsIs(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "my-password"
encrypted1, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
encrypted2, err := encryptor.Encrypt(itemID, encrypted1)
assert.NoError(t, err)
assert.Equal(t, encrypted1, encrypted2)
}
func Test_Decrypt_MalformedData_ReturnsError(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
_, err := encryptor.Decrypt(itemID, "enc:invalid")
assert.Error(t, err)
_, err = encryptor.Decrypt(itemID, "enc:invalid:invalid-base64")
assert.Error(t, err)
}
func Test_EncryptedFormat_ContainsPrefix(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "test-secret"
encrypted, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
assert.Contains(t, encrypted, "enc:")
}