Compare commits

...

6 Commits

Author SHA1 Message Date
Rostislav Dugin
95c833b619 FIX (backups): Fix passing encypted password to .pgpass 2025-11-19 17:10:19 +03:00
Rostislav Dugin
878fad5747 FEATURE (encryption): Add encyption for secrets in notifiers and storages 2025-11-18 21:23:59 +03:00
Rostislav Dugin
6ff3096695 FIX (password reset): Allow to change user password even if password was not set before 2025-11-17 20:20:31 +03:00
Rostislav Dugin
b4b514c2d5 FEATURE (encryption): Add backups encryption 2025-11-17 14:33:37 +03:00
Rostislav Dugin
da0fec6624 FEATURE (azure): Add Azure Blob Storage 2025-11-16 23:38:20 +03:00
Rostislav Dugin
408675023a FEATURE (s3): Add support of virtual-styled-domains and S3 prefix 2025-11-16 11:22:03 +03:00
108 changed files with 4869 additions and 1030 deletions

View File

@@ -137,6 +137,8 @@ jobs:
# testing S3
TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing Azure Blob
TEST_AZURITE_BLOB_PORT=10000
# testing NAS
TEST_NAS_PORT=7006
# testing Telegram
@@ -165,6 +167,9 @@ jobs:
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
- name: Create data and temp directories
run: |
# Create directories that are used for backups and restore

View File

@@ -40,13 +40,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(docs)</a>
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
- **Local storage**: Keep backups on your VPS/server
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
- **Secure**: All data stays under your control
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(docs)</a>
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
- **Real-time updates**: Success and failure notifications
@@ -58,6 +58,13 @@
- **SSL support**: Secure connections available
- **Easy restoration**: One-click restore from any backup
### 🔒 **Backup Encryption** <a href="https://postgresus.com/encryption">(docs)</a>
- **AES-256-GCM encryption**: Enterprise-grade protection for backup files
- **Zero-trust storage**: Encrypted backups are useless so you can keep in shared storages like S3, Azure Blob Storage, etc.
- **Optionality**: Encrypted backups are optional and can be enabled or disabled if you wish
- **Download unencrypted**: You can still download unencrypted backups via the 'Download' button to use them in `pg_restore` or other tools.
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 22 KiB

After

Width:  |  Height:  |  Size: 34 KiB

View File

@@ -31,4 +31,6 @@ TEST_MINIO_CONSOLE_PORT=9001
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=
TEST_TELEGRAM_CHAT_ID=
TEST_TELEGRAM_CHAT_ID=
# testing Azure Blob Storage
TEST_AZURITE_BLOB_PORT=10000

View File

@@ -31,6 +31,14 @@ services:
container_name: test-minio
command: server /data --console-address ":9001"
# Test Azurite container
test-azurite:
image: mcr.microsoft.com/azure-storage/azurite
ports:
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
container_name: test-azurite
command: azurite-blob --blobHost 0.0.0.0
# Test PostgreSQL containers
test-postgres-12:
image: postgres:12

View File

@@ -3,6 +3,8 @@ module postgresus-backend
go 1.23.3
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -15,16 +17,18 @@ require (
github.com/lib/pq v1.10.9
github.com/minio/minio-go/v7 v7.0.92
github.com/shirou/gopsutil/v4 v4.25.5
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.11.1
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
golang.org/x/crypto v0.39.0
golang.org/x/crypto v0.41.0
golang.org/x/time v0.12.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
require github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
require (
cloud.google.com/go/auth v0.16.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
@@ -99,12 +103,12 @@ require (
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
golang.org/x/arch v0.17.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/tools v0.33.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/tools v0.35.0 // indirect
google.golang.org/api v0.239.0
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect

View File

@@ -6,6 +6,18 @@ cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeO
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 h1:KpMC6LFL7mqpExyMC9jVOYRiVhLmamjeZfRsUpB7l4s=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8=
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -80,6 +92,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -131,6 +145,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -159,6 +175,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
@@ -180,8 +198,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
@@ -216,25 +234,25 @@ golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -247,8 +265,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -257,15 +275,15 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo=
google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=

View File

@@ -44,6 +44,8 @@ type EnvVariables struct {
TestMinioPort string `env:"TEST_MINIO_PORT"`
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
TestNASPort string `env:"TEST_NAS_PORT"`
// oauth
@@ -184,6 +186,11 @@ func loadEnvVariables() {
os.Exit(1)
}
if env.TestAzuriteBlobPort == "" {
log.Error("TEST_AZURITE_BLOB_PORT is empty")
os.Exit(1)
}
if env.TestNASPort == "" {
log.Error("TEST_NAS_PORT is empty")
os.Exit(1)

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/config"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/period"
"time"
)
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
continue
}
err = storage.DeleteFile(backup.ID)
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}

View File

@@ -26,6 +26,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -524,7 +525,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextMgr.RegisterBackup(backup.ID, func() {})
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -700,7 +701,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -7,7 +7,9 @@ import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"time"
)
@@ -23,6 +25,8 @@ var backupService = &BackupService{
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},

View File

@@ -1,5 +1,10 @@
package backups
import (
"io"
"postgresus-backend/internal/features/backups/backups/encryption"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
@@ -12,3 +17,12 @@ type GetBackupsResponse struct {
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type decryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
}

View File

@@ -0,0 +1,156 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type DecryptionReader struct {
baseReader io.Reader
cipher cipher.AEAD
buffer []byte
nonce []byte
chunkIndex uint64
headerRead bool
eof bool
}
func NewDecryptionReader(
baseReader io.Reader,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*DecryptionReader, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
reader := &DecryptionReader{
baseReader,
aesgcm,
make([]byte, 0),
nonce,
0,
false,
false,
}
if err := reader.readAndValidateHeader(salt, nonce); err != nil {
return nil, err
}
return reader, nil
}
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
for len(r.buffer) < len(p) && !r.eof {
if err := r.readAndDecryptChunk(); err != nil {
if err == io.EOF {
r.eof = true
break
}
return 0, err
}
}
if len(r.buffer) == 0 {
return 0, io.EOF
}
n = copy(p, r.buffer)
r.buffer = r.buffer[n:]
return n, nil
}
func (r *DecryptionReader) readAndValidateHeader(expectedSalt, expectedNonce []byte) error {
header := make([]byte, HeaderLen)
if _, err := io.ReadFull(r.baseReader, header); err != nil {
return fmt.Errorf("failed to read header: %w", err)
}
magic := string(header[0:MagicBytesLen])
if magic != MagicBytes {
return fmt.Errorf("invalid magic bytes: expected %s, got %s", MagicBytes, magic)
}
salt := header[MagicBytesLen : MagicBytesLen+SaltLen]
nonce := header[MagicBytesLen+SaltLen : MagicBytesLen+SaltLen+NonceLen]
if string(salt) != string(expectedSalt) {
return fmt.Errorf("salt mismatch in file header")
}
if string(nonce) != string(expectedNonce) {
return fmt.Errorf("nonce mismatch in file header")
}
r.headerRead = true
return nil
}
func (r *DecryptionReader) readAndDecryptChunk() error {
lengthBuf := make([]byte, 4)
if _, err := io.ReadFull(r.baseReader, lengthBuf); err != nil {
return err
}
chunkLen := binary.BigEndian.Uint32(lengthBuf)
if chunkLen == 0 || chunkLen > ChunkSize+16 {
return fmt.Errorf("invalid chunk length: %d", chunkLen)
}
encrypted := make([]byte, chunkLen)
if _, err := io.ReadFull(r.baseReader, encrypted); err != nil {
return fmt.Errorf("failed to read encrypted chunk: %w", err)
}
chunkNonce := r.generateChunkNonce()
decrypted, err := r.cipher.Open(nil, chunkNonce, encrypted, nil)
if err != nil {
return fmt.Errorf(
"failed to decrypt chunk (authentication failed - file may be corrupted or tampered): %w",
err,
)
}
r.buffer = append(r.buffer, decrypted...)
r.chunkIndex++
return nil
}
func (r *DecryptionReader) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, r.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], r.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,147 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type EncryptionWriter struct {
baseWriter io.Writer
cipher cipher.AEAD
buffer []byte
nonce []byte
salt []byte
chunkIndex uint64
headerWritten bool
}
func NewEncryptionWriter(
baseWriter io.Writer,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*EncryptionWriter, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
writer := &EncryptionWriter{
baseWriter: baseWriter,
cipher: aesgcm,
buffer: make([]byte, 0, ChunkSize),
nonce: nonce,
chunkIndex: 0,
headerWritten: false,
salt: salt, // Store salt for lazy header writing
}
return writer, nil
}
func (w *EncryptionWriter) Write(p []byte) (n int, err error) {
// Write header on first write (lazy initialization)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return 0, fmt.Errorf("failed to write header: %w", err)
}
}
n = len(p)
w.buffer = append(w.buffer, p...)
for len(w.buffer) >= ChunkSize {
chunk := w.buffer[:ChunkSize]
if err := w.encryptAndWriteChunk(chunk); err != nil {
return 0, err
}
w.buffer = w.buffer[ChunkSize:]
}
return n, nil
}
func (w *EncryptionWriter) Close() error {
// Write header if it hasn't been written yet (in case Close is called without any writes)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
}
if len(w.buffer) > 0 {
if err := w.encryptAndWriteChunk(w.buffer); err != nil {
return err
}
w.buffer = nil
}
return nil
}
func (w *EncryptionWriter) writeHeader(salt, nonce []byte) error {
header := make([]byte, HeaderLen)
copy(header[0:MagicBytesLen], []byte(MagicBytes))
copy(header[MagicBytesLen:MagicBytesLen+SaltLen], salt)
copy(header[MagicBytesLen+SaltLen:MagicBytesLen+SaltLen+NonceLen], nonce)
_, err := w.baseWriter.Write(header)
if err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
w.headerWritten = true
return nil
}
func (w *EncryptionWriter) encryptAndWriteChunk(chunk []byte) error {
chunkNonce := w.generateChunkNonce()
encrypted := w.cipher.Seal(nil, chunkNonce, chunk, nil)
lengthBuf := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBuf, uint32(len(encrypted)))
if _, err := w.baseWriter.Write(lengthBuf); err != nil {
return fmt.Errorf("failed to write chunk length: %w", err)
}
if _, err := w.baseWriter.Write(encrypted); err != nil {
return fmt.Errorf("failed to write encrypted chunk: %w", err)
}
w.chunkIndex++
return nil
}
func (w *EncryptionWriter) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, w.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], w.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,387 @@
package encryption
import (
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_EncryptDecryptRoundTrip_ReturnsOriginalData(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte(
"This is a test backup data that should be encrypted and then decrypted successfully.",
)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted := make([]byte, len(originalData))
n, err = io.ReadFull(reader, decrypted)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptDecryptRoundTrip_LargeData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := make([]byte, 100*1024)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptionWriter_MultipleWrites_CombinesCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
part1 := []byte("First part of data. ")
part2 := []byte("Second part of data. ")
part3 := []byte("Third part of data.")
expectedData := append(append(part1, part2...), part3...)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(part1)
require.NoError(t, err)
_, err = writer.Write(part2)
require.NoError(t, err)
_, err = writer.Write(part3)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, expectedData, decrypted)
}
func Test_DecryptionReader_InvalidHeader_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
invalidHeader := make([]byte, HeaderLen)
copy(invalidHeader, []byte("INVALID!"))
invalidData := bytes.NewBuffer(invalidHeader)
_, err = NewDecryptionReader(invalidData, masterKey, backupID, salt, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid magic bytes")
}
func Test_DecryptionReader_TamperedData_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This data will be tampered with.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
encryptedBytes := encrypted.Bytes()
if len(encryptedBytes) > HeaderLen+10 {
encryptedBytes[HeaderLen+10] ^= 0xFF
}
tamperedBuffer := bytes.NewBuffer(encryptedBytes)
reader, err := NewDecryptionReader(tamperedBuffer, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_DeriveBackupKey_SameInputs_ReturnsSameKey(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
assert.Equal(t, key1, key2)
}
func Test_DeriveBackupKey_DifferentInputs_ReturnsDifferentKeys(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID1 := uuid.New()
backupID2 := uuid.New()
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey1, backupID1, salt1)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey2, backupID1, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key2)
key3, err := DeriveBackupKey(masterKey1, backupID2, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key3)
key4, err := DeriveBackupKey(masterKey1, backupID1, salt2)
require.NoError(t, err)
assert.NotEqual(t, key1, key4)
}
func Test_EncryptionWriter_PartialChunk_HandledCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
smallData := []byte("Small data less than chunk size")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(smallData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, smallData, decrypted)
}
func Test_GenerateSalt_ReturnsCorrectLength(t *testing.T) {
salt, err := GenerateSalt()
require.NoError(t, err)
assert.Equal(t, SaltLen, len(salt))
}
func Test_GenerateSalt_GeneratesUniqueSalts(t *testing.T) {
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
assert.NotEqual(t, salt1, salt2)
}
func Test_GenerateNonce_ReturnsCorrectLength(t *testing.T) {
nonce, err := GenerateNonce()
require.NoError(t, err)
assert.Equal(t, NonceLen, len(nonce))
}
func Test_GenerateNonce_GeneratesUniqueNonces(t *testing.T) {
nonce1, err := GenerateNonce()
require.NoError(t, err)
nonce2, err := GenerateNonce()
require.NoError(t, err)
assert.NotEqual(t, nonce1, nonce2)
}
func Test_DecryptionReader_WrongMasterKey_ReturnsError(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("Secret data")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey1, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey2, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_EncryptionWriter_EmptyData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, 0, len(decrypted))
}
func Test_EncryptionWriter_MultipleChunks_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
dataSize := ChunkSize*3 + 1000
originalData := make([]byte, dataSize)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_DecryptionReader_SmallReads_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This is test data that will be read in small chunks.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
var decrypted []byte
buf := make([]byte, 5)
for {
n, err := reader.Read(buf)
if n > 0 {
decrypted = append(decrypted, buf[:n]...)
}
if err == io.EOF {
break
}
require.NoError(t, err)
}
assert.Equal(t, originalData, decrypted)
}

View File

@@ -0,0 +1,52 @@
package encryption
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"github.com/google/uuid"
"golang.org/x/crypto/pbkdf2"
)
const (
MagicBytes = "PGRSUS01"
MagicBytesLen = 8
SaltLen = 32
NonceLen = 12
ReservedLen = 12
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
ChunkSize = 32 * 1024
PBKDF2Iterations = 100000
)
func DeriveBackupKey(masterKey string, backupID uuid.UUID, salt []byte) ([]byte, error) {
if masterKey == "" {
return nil, fmt.Errorf("master key cannot be empty")
}
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes", SaltLen)
}
keyMaterial := []byte(masterKey + backupID.String())
derivedKey := pbkdf2.Key(keyMaterial, salt, PBKDF2Iterations, 32, sha256.New)
return derivedKey, nil
}
func GenerateSalt() ([]byte, error) {
salt := make([]byte, SaltLen)
if _, err := rand.Read(salt); err != nil {
return nil, fmt.Errorf("failed to generate salt: %w", err)
}
return salt, nil
}
func GenerateNonce() ([]byte, error) {
nonce := make([]byte, NonceLen)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
return nonce, nil
}

View File

@@ -3,6 +3,7 @@ package backups
import (
"context"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
@@ -29,7 +30,7 @@ type CreateBackupUsecase interface {
backupProgressListener func(
completedMBs float64,
),
) error
) (*usecases_postgresql.BackupMetadata, error)
}
type BackupRemoveListener interface {

View File

@@ -1,6 +1,7 @@
package backups
import (
backups_config "postgresus-backend/internal/features/backups/config"
"time"
"github.com/google/uuid"
@@ -19,5 +20,9 @@ type Backup struct {
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}

View File

@@ -2,17 +2,21 @@ package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
users_repositories "postgresus-backend/internal/features/users/repositories"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
util_encryption "postgresus-backend/internal/util/encryption"
"slices"
"strings"
"time"
@@ -27,6 +31,8 @@ type BackupService struct {
notifierService *notifiers.NotifierService
notificationSender NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyRepo *users_repositories.SecretKeyRepository
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
@@ -34,9 +40,9 @@ type BackupService struct {
backupRemoveListeners []BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextMgr *BackupContextManager
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
@@ -253,10 +259,10 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextMgr.RegisterBackup(backup.ID, cancel)
defer s.backupContextMgr.UnregisterBackup(backup.ID)
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
err = s.createBackupUseCase.Execute(
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
@@ -280,7 +286,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
@@ -326,6 +332,13 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
@@ -463,7 +476,7 @@ func (s *BackupService) CancelBackup(
return errors.New("backup is not in progress")
}
if err := s.backupContextMgr.CancelBackup(backupID); err != nil {
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
return err
}
@@ -509,11 +522,6 @@ func (s *BackupService) GetBackupFile(
return nil, errors.New("insufficient permissions to download backup for this database")
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return nil, err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
@@ -524,7 +532,7 @@ func (s *BackupService) GetBackupFile(
database.WorkspaceID,
)
return storage.GetFile(backup.ID)
return s.getBackupReader(backupID)
}
func (s *BackupService) deleteBackup(backup *Backup) error {
@@ -539,7 +547,7 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
return err
}
err = storage.DeleteFile(backup.ID)
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
@@ -579,3 +587,91 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
return nil
}
// GetBackupReader returns a reader for the backup file
// If encrypted, wraps with DecryptionReader
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("failed to find backup: %w", err)
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
// If not encrypted, return raw reader
if backup.Encryption == backups_config.BackupEncryptionNone {
s.logger.Info("Returning non-encrypted backup", "backupId", backupID)
return fileReader, nil
}
// Decrypt on-the-fly for encrypted backups
if backup.Encryption != backups_config.BackupEncryptionEncrypted {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("unsupported encryption type: %s", backup.Encryption)
}
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("backup marked as encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := s.secretKeyRepo.GetSecretKey()
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to get master key: %w", err)
}
// Decode salt and IV
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode IV: %w", err)
}
// Wrap with decrypting reader
decryptionReader, err := encryption.NewDecryptionReader(
fileReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to create decrypting reader: %w", err)
}
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
}, nil
}

View File

@@ -3,14 +3,17 @@ package backups
import (
"context"
"errors"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_repositories "postgresus-backend/internal/features/users/repositories"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strings"
"testing"
@@ -53,11 +56,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}
@@ -99,11 +104,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}
@@ -122,11 +129,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil, // auditLogService
nil,
NewBackupContextManager(),
}
@@ -171,9 +180,9 @@ func (uc *CreateFailedBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return errors.New("backup failed")
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct {
@@ -188,7 +197,11 @@ func (uc *CreateSuccessBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return nil
return &usecases_postgresql.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}

View File

@@ -15,7 +15,7 @@ type CreateBackupUsecase struct {
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
}
// Execute creates a backup of the database and returns the backup size in MB
// Execute creates a backup of the database and returns the backup metadata
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
@@ -25,7 +25,7 @@ func (uc *CreateBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
if database.Type == databases.DatabaseTypePostgres {
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
@@ -37,5 +37,5 @@ func (uc *CreateBackupUsecase) Execute(
)
}
return errors.New("database type not supported")
return nil, errors.New("database type not supported")
}

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -14,17 +15,35 @@ import (
"time"
"postgresus-backend/internal/config"
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"github.com/google/uuid"
)
const (
backupTimeout = 23 * time.Hour
shutdownCheckInterval = 1 * time.Second
copyBufferSize = 32 * 1024
progressReportIntervalMB = 1.0
pgConnectTimeout = 30
compressionLevel = 5
defaultBackupLimit = 1000
exitCodeAccessViolation = -1073741819
exitCodeGenericError = 1
exitCodeConnectionError = 2
)
type CreatePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyRepo *users_repositories.SecretKeyRepository
fieldEncryptor encryption.FieldEncryptor
}
// Execute creates a backup of the database
@@ -37,7 +56,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*BackupMetadata, error) {
uc.logger.Info(
"Creating PostgreSQL backup via pg_dump custom format",
"databaseId",
@@ -47,38 +66,24 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
)
if !backupConfig.IsBackupsEnabled {
return fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
pg := db.Postgresql
if pg == nil {
return fmt.Errorf("postgresql database configuration is required for pg_dump backups")
return nil, fmt.Errorf("postgresql database configuration is required for pg_dump backups")
}
if pg.Database == nil || *pg.Database == "" {
return fmt.Errorf("database name is required for pg_dump backups")
return nil, fmt.Errorf("database name is required for pg_dump backups")
}
args := []string{
"-Fc", // custom format with built-in compression
"--no-password", // Use environment variable for password, prevent prompts
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose", // Add verbose output to help with debugging
}
args := uc.buildPgDumpArgs(pg)
// Use zstd compression level 5 for PostgreSQL 16+ (better compression and speed)
// Fall back to gzip compression level 5 for older versions (12-15)
if pg.Version == tools.PostgresqlVersion12 || pg.Version == tools.PostgresqlVersion13 ||
pg.Version == tools.PostgresqlVersion14 || pg.Version == tools.PostgresqlVersion15 {
args = append(args, "-Z", "5")
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
} else {
args = append(args, "--compress=zstd:5")
uc.logger.Info("Using zstd compression level 5", "version", pg.Version)
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, pg.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
}
return uc.streamToStorage(
@@ -92,7 +97,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
config.GetEnv().PostgresesInstallDir,
),
args,
pg.Password,
decryptedPassword,
storage,
db,
backupProgressListener,
@@ -110,36 +115,15 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storage *storages.Storage,
db *databases.Database,
backupProgressListener func(completedMBs float64),
) error {
) (*BackupMetadata, error) {
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
// if backup not fit into 23 hours, Postgresus
// seems not to work for such database size
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
ctx, cancel := uc.createBackupContext(parentCtx)
defer cancel()
// Monitor for shutdown and cancel context if needed
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
// Create temporary .pgpass file as a more reliable alternative to PGPASSWORD
pgpassFile, err := uc.createTempPgpassFile(db.Postgresql, password)
pgpassFile, err := uc.setupPgpassFile(db.Postgresql, password)
if err != nil {
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
return nil, err
}
defer func() {
if pgpassFile != "" {
@@ -147,87 +131,21 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
}
}()
// Verify .pgpass file was created successfully
if pgpassFile == "" {
return fmt.Errorf("temporary .pgpass file was not created")
}
// Verify .pgpass file was created correctly
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return fmt.Errorf("failed to verify .pgpass file: %w", err)
}
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
// Start with system environment variables to preserve Windows PATH, SystemRoot, etc.
cmd.Env = os.Environ()
// Use the .pgpass file for authentication
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
// Debug password setup (without exposing the actual password)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", backupConfig.CpuCount,
)
// Add PostgreSQL-specific environment variables
cmd.Env = append(cmd.Env, "PGCLIENTENCODING=UTF8")
cmd.Env = append(cmd.Env, "PGCONNECT_TIMEOUT=30")
// Add encoding-related environment variables to handle character encoding issues
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
// Add PostgreSQL-specific encoding settings
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
shouldRequireSSL := db.Postgresql.IsHttps
// Require SSL when explicitly configured
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", db.Postgresql.IsHttps)
} else {
// SSL not explicitly required, but prefer it if available
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", db.Postgresql.IsHttps)
}
// Set other SSL parameters to avoid certificate issues
cmd.Env = append(cmd.Env, "PGSSLCERT=") // No client certificate
cmd.Env = append(cmd.Env, "PGSSLKEY=") // No client key
cmd.Env = append(cmd.Env, "PGSSLROOTCERT=") // No root certificate verification
cmd.Env = append(cmd.Env, "PGSSLCRL=") // No certificate revocation list
// Verify executable exists and is accessible
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf(
"PostgreSQL executable not found or not accessible: %s - %w",
pgBin,
err,
)
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, backupConfig.CpuCount, pgBin); err != nil {
return nil, err
}
pgStdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("stdout pipe: %w", err)
return nil, fmt.Errorf("stdout pipe: %w", err)
}
pgStderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("stderr pipe: %w", err)
return nil, fmt.Errorf("stderr pipe: %w", err)
}
// Capture stderr in a separate goroutine to ensure we don't miss any error output
@@ -237,23 +155,31 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
stderrCh <- stderrOutput
}()
// A pipe connecting pg_dump output → storage
storageReader, storageWriter := io.Pipe()
// Create a counting writer to track bytes
countingWriter := &CountingWriter{writer: storageWriter}
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backupConfig,
storageWriter,
)
if err != nil {
return nil, err
}
countingWriter := &CountingWriter{writer: finalWriter}
// The backup ID becomes the object key / filename in storage
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErrCh <- storage.SaveFile(uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErrCh <- saveErr
}()
// Start pg_dump
if err = cmd.Start(); err != nil {
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
return nil, fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
}
// Copy pg output directly to storage with shutdown checks
@@ -278,26 +204,14 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// Check for shutdown or cancellation before finalizing
select {
case <-ctx.Done():
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
}
<-saveErrCh // Wait for storage to finish
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
return nil, uc.checkCancellationReason()
default:
}
// Close the pipe writer to signal end of data
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
<-saveErrCh
return nil, err
}
// Wait until storage ends reading
@@ -312,149 +226,34 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
switch {
case waitErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
// Enhanced error handling for PostgreSQL connection and SSL issues
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf(
"%s failed: %v stderr: %s",
filepath.Base(pgBin),
waitErr,
stderrStr,
)
// Check for specific PostgreSQL error patterns
if exitErr, ok := waitErr.(*exec.ExitError); ok {
exitCode := exitErr.ExitCode()
// Enhanced debugging for exit status 1 with empty stderr
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=30",
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
},
)
errorMsg = fmt.Sprintf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"args", args,
"exitCode", fmt.Sprintf("0x%X", uint32(exitCode)),
)
errorMsg = fmt.Sprintf(
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
} else if exitCode == 1 || exitCode == 2 {
// Check for common connection and authentication issues
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). The server may not allow connections from your IP address or may require different authentication settings. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. stderr: %s",
len(password),
password == "",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
errorMsg = fmt.Sprintf(
"PostgreSQL SSL connection failed. The server may require SSL encryption or have SSL configuration issues. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection refused. Check if the server is running and accessible from your network. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "timeout") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
}
}
return errors.New(errorMsg)
return nil, uc.buildPgDumpErrorMessage(waitErr, stderrOutput, pgBin, args, password)
case copyErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("copy to storage: %w", copyErr)
return nil, fmt.Errorf("copy to storage: %w", copyErr)
case saveErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("save to storage: %w", saveErr)
return nil, fmt.Errorf("save to storage: %w", saveErr)
}
return nil
return &backupMetadata, nil
}
// copyWithShutdownCheck copies data from src to dst while checking for shutdown
func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
ctx context.Context,
dst io.Writer,
src io.Reader,
backupProgressListener func(completedMBs float64),
) (int64, error) {
buf := make([]byte, 32*1024) // 32KB buffer
buf := make([]byte, copyBufferSize)
var totalBytesWritten int64
// Progress reporting interval - report every 1MB of data
var lastReportedMB float64
const reportIntervalMB = 1.0
for {
select {
@@ -487,12 +286,9 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
totalBytesWritten += int64(bytesWritten)
// Report progress based on total size
if backupProgressListener != nil {
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
// Only report if we've written at least 1MB more data than last report
if currentSizeMB >= lastReportedMB+reportIntervalMB {
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
backupProgressListener(currentSizeMB)
lastReportedMB = currentSizeMB
}
@@ -503,7 +299,6 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
if readErr != io.EOF {
return totalBytesWritten, readErr
}
break
}
}
@@ -511,12 +306,412 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
return totalBytesWritten, nil
}
// containsIgnoreCase checks if a string contains a substring, ignoring case
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlDatabase) []string {
args := []string{
"-Fc",
"--no-password",
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose",
}
compressionArgs := uc.getCompressionArgs(pg.Version)
return append(args, compressionArgs...)
}
func (uc *CreatePostgresqlBackupUsecase) getCompressionArgs(
version tools.PostgresqlVersion,
) []string {
if uc.isOlderPostgresVersion(version) {
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", version)
return []string{"-Z", strconv.Itoa(compressionLevel)}
}
uc.logger.Info("Using zstd compression level 5", "version", version)
return []string{fmt.Sprintf("--compress=zstd:%d", compressionLevel)}
}
func (uc *CreatePostgresqlBackupUsecase) isOlderPostgresVersion(
version tools.PostgresqlVersion,
) bool {
return version == tools.PostgresqlVersion12 ||
version == tools.PostgresqlVersion13 ||
version == tools.PostgresqlVersion14 ||
version == tools.PostgresqlVersion15
}
func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
parentCtx context.Context,
) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
go func() {
ticker := time.NewTicker(shutdownCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
return ctx, cancel
}
func (uc *CreatePostgresqlBackupUsecase) setupPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
) (string, error) {
pgpassFile, err := uc.createTempPgpassFile(pgConfig, password)
if err != nil {
return "", fmt.Errorf("failed to create temporary .pgpass file: %w", err)
}
if pgpassFile == "" {
return "", fmt.Errorf("temporary .pgpass file was not created")
}
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return "", fmt.Errorf("failed to verify .pgpass file: %w", err)
}
return pgpassFile, nil
}
func (uc *CreatePostgresqlBackupUsecase) setupPgEnvironment(
cmd *exec.Cmd,
pgpassFile string,
shouldRequireSSL bool,
password string,
cpuCount int,
pgBin string,
) error {
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", cpuCount,
)
cmd.Env = append(cmd.Env,
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT="+strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
)
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", shouldRequireSSL)
} else {
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", shouldRequireSSL)
}
cmd.Env = append(cmd.Env,
"PGSSLCERT=",
"PGSSLKEY=",
"PGSSLROOTCERT=",
"PGSSLCRL=",
)
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf("PostgreSQL executable not found or not accessible: %s - %w", pgBin, err)
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
metadata := BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyRepo.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
saveErrCh chan error,
) {
if encryptionWriter != nil {
go func() {
if closeErr := encryptionWriter.Close(); closeErr != nil {
uc.logger.Error(
"Failed to close encrypting writer during cancellation",
"error",
closeErr,
)
}
}()
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
}
<-saveErrCh
}
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
) error {
encryptionCloseErrCh := make(chan error, 1)
if encryptionWriter != nil {
go func() {
closeErr := encryptionWriter.Close()
if closeErr != nil {
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
}
encryptionCloseErrCh <- closeErr
}()
} else {
encryptionCloseErrCh <- nil
}
encryptionCloseErr := <-encryptionCloseErrCh
if encryptionCloseErr != nil {
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
}
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer", "error", err)
return err
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellation(ctx context.Context) error {
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
return nil
}
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellationReason() error {
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
}
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
password string,
) error {
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf("%s failed: %v stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
return errors.New(errorMsg)
}
exitCode := exitErr.ExitCode()
if exitCode == exitCodeGenericError && strings.TrimSpace(stderrStr) == "" {
return uc.handleExitCode1NoStderr(pgBin, args)
}
if exitCode == exitCodeAccessViolation {
return uc.handleAccessViolation(pgBin, stderrStr)
}
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
return uc.handleConnectionErrors(stderrStr, password)
}
return errors.New(errorMsg)
}
func (uc *CreatePostgresqlBackupUsecase) handleExitCode1NoStderr(
pgBin string,
args []string,
) error {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=" + strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
},
)
return fmt.Errorf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
}
func (uc *CreatePostgresqlBackupUsecase) handleAccessViolation(
pgBin string,
stderrStr string,
) error {
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"exitCode", "0xC0000005",
)
return fmt.Errorf(
"%s crashed with access violation (0xC0000005). "+
"This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. "+
"stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
}
func (uc *CreatePostgresqlBackupUsecase) handleConnectionErrors(
stderrStr string,
password string,
) error {
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
return fmt.Errorf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). "+
"The server may not allow connections from your IP address or may require different authentication settings. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "no password supplied") ||
containsIgnoreCase(stderrStr, "fe_sendauth") {
return fmt.Errorf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. "+
"stderr: %s",
len(password),
password == "",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
return fmt.Errorf(
"PostgreSQL SSL connection failed. "+
"The server may require SSL encryption or have SSL configuration issues. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
return fmt.Errorf(
"PostgreSQL connection refused. "+
"Check if the server is running and accessible from your network. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "authentication") ||
containsIgnoreCase(stderrStr, "password") {
return fmt.Errorf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "timeout") {
return fmt.Errorf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
return fmt.Errorf("PostgreSQL connection or authentication error. stderr: %s", stderrStr)
}
// createTempPgpassFile creates a temporary .pgpass file with the given password
func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
@@ -532,7 +727,6 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
password,
)
// it always create unique directory like /tmp/pgpass-1234567890
tempDir, err := os.MkdirTemp("", "pgpass")
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
@@ -546,3 +740,7 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
return pgpassFile, nil
}
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
}

View File

@@ -1,11 +1,15 @@
package usecases_postgresql
import (
users_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
logger.GetLogger(),
users_repositories.GetSecretKeyRepository(),
encryption.GetFieldEncryptor(),
}
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {

View File

@@ -0,0 +1,15 @@
package usecases_postgresql
import backups_config "postgresus-backend/internal/features/backups/config"
type EncryptionMetadata struct {
Salt string
IV string
Encryption backups_config.BackupEncryption
}
type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
}

View File

@@ -20,15 +20,15 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
// SaveBackupConfig
// @Summary Save backup configuration
// @Description Save or update backup configuration for a database
// @Description Save or update backup configuration for a database. Encryption can be set to NONE (no encryption) or ENCRYPTED (AES-256-GCM encryption).
// @Tags backup-configs
// @Accept json
// @Produce json
// @Param request body BackupConfig true "Backup configuration data"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 500
// @Param request body BackupConfig true "Backup configuration data (encryption field: NONE or ENCRYPTED)"
// @Success 200 {object} BackupConfig "Returns the saved backup configuration including encryption settings"
// @Failure 400 {object} map[string]string "Invalid encryption value or other validation errors"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 500 {object} map[string]string "Internal server error"
// @Router /backup-configs/save [post]
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -57,14 +57,14 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
// GetBackupConfigByDbID
// @Summary Get backup configuration by database ID
// @Description Get backup configuration for a specific database
// @Description Get backup configuration for a specific database including encryption settings (NONE or ENCRYPTED)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 404
// @Success 200 {object} BackupConfig "Returns backup configuration with encryption field"
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Backup configuration not found"
// @Router /backup-configs/database/{id} [get]
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)

View File

@@ -368,6 +368,86 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
}
}
func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionNone, response.Encryption)
}
func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionEncrypted,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionEncrypted, response.Encryption)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,

View File

@@ -6,3 +6,10 @@ const (
NotificationBackupFailed BackupNotificationType = "BACKUP_FAILED"
NotificationBackupSuccess BackupNotificationType = "BACKUP_SUCCESS"
)
type BackupEncryption string
const (
BackupEncryptionNone BackupEncryption = "NONE"
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)

View File

@@ -31,6 +31,8 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
CpuCount int `json:"cpuCount" gorm:"type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
}
func (h *BackupConfig) TableName() string {
@@ -88,6 +90,11 @@ func (b *BackupConfig) Validate() error {
return errors.New("max failed tries count must be greater than 0")
}
if b.Encryption != "" && b.Encryption != BackupEncryptionNone &&
b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption must be NONE or ENCRYPTED")
}
return nil
}
@@ -103,5 +110,6 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
CpuCount: b.CpuCount,
Encryption: b.Encryption,
}
}

View File

@@ -171,6 +171,7 @@ func (s *BackupConfigService) initializeDefaultConfig(
CpuCount: 1,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
})
return err

View File

@@ -16,6 +16,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -769,6 +770,71 @@ func createTestDatabaseViaAPI(
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
},
}
var createdDatabase Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusCreated,
&createdDatabase,
)
repository := &DatabaseRepository{}
databaseFromDB, err := repository.FindByID(createdDatabase.ID)
assert.NoError(t, err)
assert.NotNil(t, databaseFromDB)
assert.NotNil(t, databaseFromDB.Postgresql)
assert.True(
t,
strings.HasPrefix(databaseFromDB.Postgresql.Password, "enc:"),
"Password should be encrypted in database with 'enc:' prefix, got: %s",
databaseFromDB.Postgresql.Password,
)
encryptor := encryption.GetFieldEncryptor()
decryptedPassword, err := encryptor.Decrypt(
databaseFromDB.ID,
databaseFromDB.Postgresql.Password,
)
assert.NoError(t, err)
assert.Equal(t, plainPassword, decryptedPassword,
"Decrypted password should match original plaintext password")
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+createdDatabase.ID.String(),
"Bearer "+owner.Token,
http.StatusNoContent,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -815,7 +881,15 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
assert.Equal(t, "original-password-secret", database.Postgresql.Password)
// Verify password is encrypted
assert.True(t, strings.HasPrefix(database.Postgresql.Password, "enc:"),
"Password should be encrypted in database")
// Verify it can be decrypted back to original
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)

View File

@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"regexp"
"slices"
"time"
"github.com/google/uuid"
@@ -59,11 +59,15 @@ func (p *PostgresqlDatabase) Validate() error {
return nil
}
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
func (p *PostgresqlDatabase) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return testSingleDatabaseConnection(logger, ctx, p)
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
}
func (p *PostgresqlDatabase) HideSensitiveData() {
@@ -87,19 +91,42 @@ func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
}
}
func (p *PostgresqlDatabase) EncryptSensitiveFields(
databaseID uuid.UUID,
encryptor encryption.FieldEncryptor,
) error {
if p.Password != "" {
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
if err != nil {
return err
}
p.Password = encrypted
}
return nil
}
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
func testSingleDatabaseConnection(
logger *slog.Logger,
ctx context.Context,
postgresDb *PostgresqlDatabase,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
// For single database backup, we need to connect to the specific database
if postgresDb.Database == nil || *postgresDb.Database == "" {
return errors.New("database name is required for single database backup (pg_dump)")
}
// Decrypt password if needed
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
// Build connection string for the specific database
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
// Test connection
conn, err := pgx.Connect(ctx, connStr)
@@ -182,7 +209,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
}
// buildConnectionStringForDB builds connection string for specific database
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
sslMode := "disable"
if p.IsHttps {
sslMode = "require"
@@ -192,106 +219,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
p.Host,
p.Port,
p.Username,
p.Password,
password,
dbName,
sslMode,
)
}
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
if len(extensions) == 0 {
return nil
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, error) {
if encryptor == nil {
return password, nil
}
if p.Database == nil || *p.Database == "" {
return errors.New("database name is required for installing extensions")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Build connection string for the specific database
connStr := buildConnectionStringForDB(p, *p.Database)
// Connect to database
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
fmt.Println("failed to close connection: %w", closeErr)
}
}()
// Check which extensions are already installed
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
if err != nil {
return fmt.Errorf("failed to check installed extensions: %w", err)
}
// Install missing extensions
for _, extension := range extensions {
if contains(installedExtensions, string(extension)) {
continue // Extension already installed
}
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
}
}
return nil
}
// getInstalledExtensions queries the database for currently installed extensions
func (p *PostgresqlDatabase) getInstalledExtensions(
ctx context.Context,
conn *pgx.Conn,
) ([]string, error) {
query := "SELECT extname FROM pg_extension"
rows, err := conn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
}
defer rows.Close()
var extensions []string
for rows.Next() {
var extname string
if err := rows.Scan(&extname); err != nil {
return nil, fmt.Errorf("failed to scan extension name: %w", err)
}
extensions = append(extensions, extname)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
}
return extensions, nil
}
// installExtension installs a single PostgreSQL extension
func (p *PostgresqlDatabase) installExtension(
ctx context.Context,
conn *pgx.Conn,
extensionName string,
) error {
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
_, err := conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
}
return nil
}
// contains checks if a string slice contains a specific string
func contains(slice []string, item string) bool {
return slices.Contains(slice, item)
return encryptor.Decrypt(databaseID, password)
}

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -19,6 +20,7 @@ var databaseService = &DatabaseService{
[]DatabaseCopyListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var databaseController = &DatabaseController{

View File

@@ -2,6 +2,7 @@ package databases
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -11,7 +12,11 @@ type DatabaseValidator interface {
}
type DatabaseConnector interface {
TestConnection(logger *slog.Logger) error
TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error
HideSensitiveData()
}

View File

@@ -5,6 +5,7 @@ import (
"log/slog"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -56,14 +57,24 @@ func (d *Database) ValidateUpdate(old, new Database) error {
return nil
}
func (d *Database) TestConnection(logger *slog.Logger) error {
return d.getSpecificDatabase().TestConnection(logger)
func (d *Database) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
if d.Postgresql != nil {
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
}
return nil
}
func (d *Database) Update(incoming *Database) {
d.Name = incoming.Name
d.Type = incoming.Type

View File

@@ -11,6 +11,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -26,6 +27,7 @@ type DatabaseService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *DatabaseService) AddDbCreationListener(
@@ -65,6 +67,10 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
database, err = s.dbRepository.Save(database)
if err != nil {
return nil, err
@@ -118,6 +124,10 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
_, err = s.dbRepository.Save(existingDatabase)
if err != nil {
return err
@@ -250,7 +260,7 @@ func (s *DatabaseService) TestDatabaseConnection(
return errors.New("insufficient permissions to test connection for this database")
}
err = database.TestConnection(s.logger)
err = database.TestConnection(s.logger, s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
database.LastBackupErrorMessage = &lastSaveError
@@ -294,7 +304,7 @@ func (s *DatabaseService) TestDatabaseConnectionDirect(
usingDatabase = database
}
return usingDatabase.TestConnection(s.logger)
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) GetDatabaseByID(

View File

@@ -453,70 +453,6 @@ func Test_CrossWorkspaceSecurity_CannotAccessNotifierFromAnotherWorkspace(t *tes
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -553,7 +489,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-bot-token-12345", notifier.TelegramNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "original-bot-token-12345", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.TelegramNotifier.BotToken)
@@ -592,7 +534,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-password-secret", notifier.EmailNotifier.SMTPPassword)
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.EmailNotifier.SMTPPassword)
@@ -625,7 +573,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "xoxb-original-slack-token", notifier.SlackNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "xoxb-original-slack-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.SlackNotifier.BotToken)
@@ -656,11 +610,17 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(
assert.True(
t,
"https://discord.com/api/webhooks/123/original-token",
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/original-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.DiscordNotifier.ChannelWebhookURL)
@@ -691,10 +651,16 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(
t,
"https://outlook.office.com/webhook/original-token",
notifier.TeamsNotifier.WebhookURL,
decrypted,
)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
@@ -813,3 +779,263 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
testCases := []struct {
name string
createNotifier func(workspaceID uuid.UUID) *Notifier
verifySensitiveEncryption func(t *testing.T, notifier *Notifier)
}{
{
name: "Telegram Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram",
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: "plain-telegram-token-123",
TargetChatID: "123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "plain-telegram-token-123", decrypted)
},
},
{
name: "Email Notifier - SMTPPassword encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Email",
NotifierType: NotifierTypeEmail,
EmailNotifier: &email_notifier.EmailNotifier{
TargetEmail: "test@example.com",
SMTPHost: "smtp.example.com",
SMTPPort: 587,
SMTPUser: "user@example.com",
SMTPPassword: "plain-smtp-password-456",
From: "noreply@example.com",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "plain-smtp-password-456", decrypted)
},
},
{
name: "Slack Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Slack",
NotifierType: NotifierTypeSlack,
SlackNotifier: &slack_notifier.SlackNotifier{
BotToken: "plain-slack-token-789",
TargetChatID: "C0123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "plain-slack-token-789", decrypted)
},
},
{
name: "Discord Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Discord",
NotifierType: NotifierTypeDiscord,
DiscordNotifier: &discord_notifier.DiscordNotifier{
ChannelWebhookURL: "https://discord.com/api/webhooks/123/abc",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/abc", decrypted)
},
},
{
name: "Teams Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Teams",
NotifierType: NotifierTypeTeams,
TeamsNotifier: &teams_notifier.TeamsNotifier{
WebhookURL: "https://outlook.office.com/webhook/test123",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(t, "https://outlook.office.com/webhook/test123", decrypted)
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Webhook",
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Create notifier via API (plaintext credentials)
var createdNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
tc.createNotifier(workspace.ID),
http.StatusOK,
&createdNotifier,
)
// Read from DB directly (bypass service layer)
repository := &NotifierRepository{}
notifierFromDB, err := repository.FindByID(createdNotifier.ID)
assert.NoError(t, err)
// Verify encryption
tc.verifySensitiveEncryption(t, notifierFromDB)
// Cleanup
deleteNotifier(t, router, createdNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func isEncrypted(value string) bool {
return len(value) > 4 && value[:4] == "enc:"
}
func decryptField(t *testing.T, notifierID uuid.UUID, encryptedValue string) string {
encryptor := GetNotifierService().fieldEncryptor
decrypted, err := encryptor.Decrypt(notifierID, encryptedValue)
assert.NoError(t, err)
return decrypted
}

View File

@@ -3,6 +3,7 @@ package notifiers
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -12,6 +13,7 @@ var notifierService = &NotifierService{
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var notifierController = &NotifierController{
notifierService,

View File

@@ -1,11 +1,21 @@
package notifiers
import "log/slog"
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
)
type NotificationSender interface {
Send(logger *slog.Logger, heading string, message string) error
Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -9,6 +9,7 @@ import (
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
return "notifiers"
}
func (n *Notifier) Validate() error {
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.Name == "" {
return errors.New("name is required")
}
return n.getSpecificNotifier().Validate()
return n.getSpecificNotifier().Validate(encryptor)
}
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
err := n.getSpecificNotifier().Send(logger, heading, message)
func (n *Notifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
if err != nil {
lastSendError := err.Error()
@@ -58,6 +64,10 @@ func (n *Notifier) HideSensitiveData() {
n.getSpecificNotifier().HideSensitiveData()
}
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
}
func (n *Notifier) Update(incoming *Notifier) {
n.Name = incoming.Name
n.NotifierType = incoming.NotifierType

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
return "discord_notifiers"
}
func (d *DiscordNotifier) Validate() error {
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
return nil
}
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (d *DiscordNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
@@ -81,3 +92,14 @@ func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
d.ChannelWebhookURL = incoming.ChannelWebhookURL
}
}
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL != "" {
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
d.ChannelWebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net"
"net/smtp"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -34,7 +35,7 @@ func (e *EmailNotifier) TableName() string {
return "email_notifiers"
}
func (e *EmailNotifier) Validate() error {
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if e.TargetEmail == "" {
return errors.New("target email is required")
}
@@ -55,7 +56,22 @@ func (e *EmailNotifier) Validate() error {
return nil
}
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (e *EmailNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
// Decrypt SMTP password if provided
var smtpPassword string
if e.SMTPPassword != "" {
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
}
smtpPassword = decrypted
}
// Compose email
from := e.From
if from == "" {
@@ -85,7 +101,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
timeout := DefaultTimeout
// Determine if authentication is required
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
// Handle different port scenarios
if e.SMTPPort == ImplicitTLSPort {
@@ -116,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Set up authentication only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -179,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Authenticate only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -229,3 +245,14 @@ func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.SMTPPassword = incoming.SMTPPassword
}
}
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if e.SMTPPassword != "" {
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
}
e.SMTPPassword = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
"time"
@@ -23,7 +24,7 @@ type SlackNotifier struct {
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
func (s *SlackNotifier) Validate() error {
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if s.BotToken == "" {
return errors.New("bot token is required")
}
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
return nil
}
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
func (s *SlackNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
full := fmt.Sprintf("*%s*", heading)
if message != "" {
@@ -80,7 +90,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Authorization", "Bearer "+s.BotToken)
req.Header.Set("Authorization", "Bearer "+botToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@@ -144,3 +154,14 @@ func (s *SlackNotifier) Update(incoming *SlackNotifier) {
s.BotToken = incoming.BotToken
}
}
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if s.BotToken != "" {
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
s.BotToken = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
return "teams_notifiers"
}
func (n *TeamsNotifier) Validate() error {
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL == "" {
return errors.New("webhook_url is required")
}
u, err := url.Parse(n.WebhookURL)
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
u, err := url.Parse(webhookURL)
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
return errors.New("invalid webhook_url")
}
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
}
type cardAttachment struct {
ContentType string `json:"contentType"`
Content interface{} `json:"content"`
ContentType string `json:"contentType"`
Content any `json:"content"`
}
type payload struct {
@@ -43,11 +50,20 @@ type payload struct {
Attachments []cardAttachment `json:"attachments,omitempty"`
}
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
if err := n.Validate(); err != nil {
func (n *TeamsNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
if err := n.Validate(encryptor); err != nil {
return err
}
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
card := map[string]any{
"type": "AdaptiveCard",
"version": "1.4",
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
}
body, _ := json.Marshal(p)
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return err
}
@@ -104,3 +120,14 @@ func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
n.WebhookURL = incoming.WebhookURL
}
}
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
n.WebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
return "telegram_notifiers"
}
func (t *TelegramNotifier) Validate() error {
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.BotToken == "" {
return errors.New("bot token is required")
}
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
return nil
}
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *TelegramNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
}
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
data := url.Values{}
data.Set("chat_id", t.TargetChatID)
@@ -93,3 +104,14 @@ func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
t.BotToken = incoming.BotToken
}
}
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.BotToken != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
t.BotToken = encrypted
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
return "webhook_notifiers"
}
func (t *WebhookNotifier) Validate() error {
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
return nil
}
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *WebhookNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
switch t.WebhookMethod {
case WebhookMethodGET:
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
t.WebhookURL,
webhookURL,
url.QueryEscape(heading),
url.QueryEscape(message),
)
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal webhook payload: %w", err)
}
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to send POST webhook: %w", err)
}
@@ -110,3 +121,14 @@ func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
t.WebhookURL = incoming.WebhookURL
t.WebhookMethod = incoming.WebhookMethod
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
t.WebhookURL = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -17,6 +18,7 @@ type NotifierService struct {
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *NotifierService) SaveNotifier(
@@ -46,7 +48,11 @@ func (s *NotifierService) SaveNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -63,7 +69,11 @@ func (s *NotifierService) SaveNotifier(
} else {
notifier.WorkspaceID = workspaceID
if err := notifier.Validate(); err != nil {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := notifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -175,7 +185,7 @@ func (s *NotifierService) SendTestNotification(
return errors.New("insufficient permissions to test notifier in this workspace")
}
err = notifier.Send(s.logger, "Test message", "This is a test message")
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
if err != nil {
return err
}
@@ -205,16 +215,24 @@ func (s *NotifierService) SendTestNotificationToNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = existingNotifier
} else {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = notifier
}
return usingNotifier.Send(s.logger, "Test message", "This is a test message")
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
}
func (s *NotifierService) SendNotification(
@@ -233,7 +251,7 @@ func (s *NotifierService) SendNotification(
return
}
err = notifiedFromDb.Send(s.logger, title, message)
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
if err != nil {
errMsg := err.Error()
notifiedFromDb.LastSendError = &errMsg

View File

@@ -29,6 +29,7 @@ import (
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
util_encryption "postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -309,6 +310,7 @@ func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -338,7 +340,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -1,11 +1,13 @@
package usecases_postgresql
import (
users_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/util/logger"
)
var restorePostgresqlBackupUsecase = &RestorePostgresqlBackupUsecase{
logger.GetLogger(),
users_repositories.GetSecretKeyRepository(),
}
func GetRestorePostgresqlBackupUsecase() *RestorePostgresqlBackupUsecase {

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -15,11 +16,14 @@ import (
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
users_repositories "postgresus-backend/internal/features/users/repositories"
util_encryption "postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/tools"
@@ -27,7 +31,8 @@ import (
)
type RestorePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyRepo *users_repositories.SecretKeyRepository
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
@@ -202,18 +207,67 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
backup.ID,
"tempFile",
tempBackupFile,
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
backupReader, err := storage.GetFile(backup.ID)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := backupReader.Close(); err != nil {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
// Create a reader that handles decryption if needed
var backupReader io.Reader = rawReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
// Validate encryption metadata
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
cleanupFunc()
return "", nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := uc.secretKeyRepo.GetSecretKey()
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get master key for decryption: %w", err)
}
// Decode salt and IV from base64
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption IV: %w", err)
}
// Create decryption reader
decryptReader, err := encryption.NewDecryptionReader(
rawReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to create decryption reader: %w", err)
}
backupReader = decryptReader
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
}
// Create temporary backup file
tempFile, err := os.Create(tempBackupFile)
if err != nil {

View File

@@ -3,10 +3,14 @@ package storages
import (
"fmt"
"net/http"
"strings"
"testing"
audit_logs "postgresus-backend/internal/features/audit_logs"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
@@ -14,6 +18,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
@@ -438,6 +443,386 @@ func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *test
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
"S3AccessKey should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
"S3SecretKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
assert.NoError(t, err)
assert.Equal(t, "original-access-key", accessKey)
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
assert.NoError(t, err)
assert.Equal(t, "original-secret-key", secretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
{
name: "NAS Storage",
storageType: StorageTypeNAS,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Test NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas.example.com",
Port: 445,
Share: "backups",
Username: "testuser",
Password: "original-password",
UseSSL: false,
Domain: "WORKGROUP",
Path: "/test",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Updated NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas2.example.com",
Port: 445,
Share: "backups2",
Username: "testuser2",
Password: "",
UseSSL: true,
Domain: "WORKGROUP2",
Path: "/test2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.NASStorage.Password)
},
},
{
name: "Azure Blob Storage (Connection String)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "original-connection-string",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
"ConnectionString should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
connectionString, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.ConnectionString,
)
assert.NoError(t, err)
assert.Equal(t, "original-connection-string", connectionString)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Azure Blob Storage (Account Key)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "testaccount",
AccountKey: "original-account-key",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "updatedaccount",
AccountKey: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
"AccountKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accountKey, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.AccountKey,
)
assert.NoError(t, err)
assert.Equal(t, "original-account-key", accountKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Google Drive Storage",
storageType: StorageTypeGoogleDrive,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Test Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Updated Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "updated-client-id",
ClientSecret: "",
TokenJSON: "",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
"ClientSecret should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
"TokenJSON should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
clientSecret, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.ClientSecret,
)
assert.NoError(t, err)
assert.Equal(t, "original-client-secret", clientSecret)
tokenJSON, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.TokenJSON,
)
assert.NoError(t, err)
assert.Equal(
t,
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
tokenJSON,
)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Verify sensitive data is encrypted in repository after creation
repository := &StorageRepository{}
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
tc.verifySensitiveData(t, storageFromDBAfterCreate)
// Phase 3: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -485,158 +870,3 @@ func deleteStorage(
http.StatusOK,
)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "original-access-key", storage.S3Storage.S3AccessKey)
assert.Equal(t, "original-secret-key", storage.S3Storage.S3SecretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &StorageRepository{}
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}

View File

@@ -3,6 +3,7 @@ package storages
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
)
var storageRepository = &StorageRepository{}
@@ -10,6 +11,7 @@ var storageService = &StorageService{
storageRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var storageController = &StorageController{
storageService,

View File

@@ -7,4 +7,5 @@ const (
StorageTypeS3 StorageType = "S3"
StorageTypeGoogleDrive StorageType = "GOOGLE_DRIVE"
StorageTypeNAS StorageType = "NAS"
StorageTypeAzureBlob StorageType = "AZURE_BLOB"
)

View File

@@ -3,20 +3,28 @@ package storages
import (
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
type StorageFileSaver interface {
SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error
SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error
GetFile(fileID uuid.UUID) (io.ReadCloser, error)
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
DeleteFile(fileID uuid.UUID) error
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
TestConnection() error
TestConnection(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -4,10 +4,12 @@ import (
"errors"
"io"
"log/slog"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -24,10 +26,16 @@ type Storage struct {
S3Storage *s3_storage.S3Storage `json:"s3Storage" gorm:"foreignKey:StorageID"`
GoogleDriveStorage *google_drive_storage.GoogleDriveStorage `json:"googleDriveStorage" gorm:"foreignKey:StorageID"`
NASStorage *nas_storage.NASStorage `json:"nasStorage" gorm:"foreignKey:StorageID"`
AzureBlobStorage *azure_blob_storage.AzureBlobStorage `json:"azureBlobStorage" gorm:"foreignKey:StorageID"`
}
func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
err := s.getSpecificStorage().SaveFile(logger, fileID, file)
func (s *Storage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
err := s.getSpecificStorage().SaveFile(encryptor, logger, fileID, file)
if err != nil {
lastSaveError := err.Error()
s.LastSaveError = &lastSaveError
@@ -39,15 +47,18 @@ func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader
return nil
}
func (s *Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(fileID)
func (s *Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(encryptor, fileID)
}
func (s *Storage) DeleteFile(fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(fileID)
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
}
func (s *Storage) Validate() error {
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.Type == "" {
return errors.New("storage type is required")
}
@@ -56,17 +67,21 @@ func (s *Storage) Validate() error {
return errors.New("storage name is required")
}
return s.getSpecificStorage().Validate()
return s.getSpecificStorage().Validate(encryptor)
}
func (s *Storage) TestConnection() error {
return s.getSpecificStorage().TestConnection()
func (s *Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().TestConnection(encryptor)
}
func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
}
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type
@@ -88,6 +103,10 @@ func (s *Storage) Update(incoming *Storage) {
if s.NASStorage != nil && incoming.NASStorage != nil {
s.NASStorage.Update(incoming.NASStorage)
}
case StorageTypeAzureBlob:
if s.AzureBlobStorage != nil && incoming.AzureBlobStorage != nil {
s.AzureBlobStorage.Update(incoming.AzureBlobStorage)
}
}
}
@@ -101,6 +120,8 @@ func (s *Storage) getSpecificStorage() StorageFileSaver {
return s.GoogleDriveStorage
case StorageTypeNAS:
return s.NASStorage
case StorageTypeAzureBlob:
return s.AzureBlobStorage
default:
panic("invalid storage type: " + string(s.Type))
}

View File

@@ -8,15 +8,18 @@ import (
"os"
"path/filepath"
"postgresus-backend/internal/config"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strconv"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/google/uuid"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
@@ -32,6 +35,15 @@ type S3Container struct {
region string
}
type AzuriteContainer struct {
endpoint string
accountName string
accountKey string
containerNameKey string
containerNameStr string
connectionString string
}
func Test_Storage_BasicOperations(t *testing.T) {
ctx := context.Background()
@@ -41,6 +53,10 @@ func Test_Storage_BasicOperations(t *testing.T) {
s3Container, err := setupS3Container(ctx)
require.NoError(t, err, "Failed to setup S3 container")
// Setup Azurite connection
azuriteContainer, err := setupAzuriteContainer(ctx)
require.NoError(t, err, "Failed to setup Azurite container")
// Setup test file
testFilePath, err := setupTestFile()
require.NoError(t, err, "Failed to setup test file")
@@ -88,6 +104,26 @@ func Test_Storage_BasicOperations(t *testing.T) {
Path: "test-files",
},
},
{
name: "AzureBlobStorage_AccountKey",
storage: &azure_blob_storage.AzureBlobStorage{
StorageID: uuid.New(),
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: azuriteContainer.accountName,
AccountKey: azuriteContainer.accountKey,
ContainerName: azuriteContainer.containerNameKey,
Endpoint: azuriteContainer.endpoint,
},
},
{
name: "AzureBlobStorage_ConnectionString",
storage: &azure_blob_storage.AzureBlobStorage{
StorageID: uuid.New(),
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: azuriteContainer.connectionString,
ContainerName: azuriteContainer.containerNameStr,
},
},
}
// Add Google Drive storage test only if environment variables are available
@@ -112,13 +148,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
t.Run("Test_TestConnection_ConnectionSucceeds", func(t *testing.T) {
err := tc.storage.TestConnection()
err := tc.storage.TestConnection(encryptor)
assert.NoError(t, err, "TestConnection should succeed")
})
t.Run("Test_TestValidation_ValidationSucceeds", func(t *testing.T) {
err := tc.storage.Validate()
err := tc.storage.Validate(encryptor)
assert.NoError(t, err, "Validate should succeed")
})
@@ -128,10 +166,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.NoError(t, err, "GetFile should succeed")
defer file.Close()
@@ -145,13 +188,18 @@ func Test_Storage_BasicOperations(t *testing.T) {
require.NoError(t, err, "Should be able to read test file")
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
err = tc.storage.DeleteFile(fileID)
err = tc.storage.DeleteFile(encryptor, fileID)
assert.NoError(t, err, "DeleteFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.Error(t, err, "GetFile should fail for non-existent file")
if file != nil {
file.Close()
@@ -161,7 +209,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
// Try to delete a non-existent file
nonExistentID := uuid.New()
err := tc.storage.DeleteFile(nonExistentID)
err := tc.storage.DeleteFile(encryptor, nonExistentID)
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
})
})
@@ -230,8 +278,59 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
}, nil
}
func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
env := config.GetEnv()
accountName := "devstoreaccount1"
// this is real testing key for azurite, it's not a real key
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
containerNameKey := "test-container-key"
containerNameStr := "test-container-connstr"
// Build explicit connection string for Azurite
connectionString := fmt.Sprintf(
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
accountName,
accountKey,
env.TestAzuriteBlobPort,
accountName,
)
// Create client using connection string to set up containers
client, err := azblob.NewClientFromConnectionString(connectionString, nil)
if err != nil {
return nil, fmt.Errorf("failed to create azblob client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// Create container for account key auth
_, err = client.CreateContainer(ctx, containerNameKey, nil)
if err != nil {
// Container might already exist, that's okay
}
// Create container for connection string auth
_, err = client.CreateContainer(ctx, containerNameStr, nil)
if err != nil {
// Container might already exist, that's okay
}
return &AzuriteContainer{
endpoint: serviceURL,
accountName: accountName,
accountKey: accountKey,
containerNameKey: containerNameKey,
containerNameStr: containerNameStr,
connectionString: connectionString,
}, nil
}
func validateEnvVariables(t *testing.T) {
env := config.GetEnv()
assert.NotEmpty(t, env.TestMinioPort, "TEST_MINIO_PORT is empty")
assert.NotEmpty(t, env.TestAzuriteBlobPort, "TEST_AZURITE_BLOB_PORT is empty")
assert.NotEmpty(t, env.TestNASPort, "TEST_NAS_PORT is empty")
}

View File

@@ -0,0 +1,303 @@
package azure_blob_storage
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/google/uuid"
)
type AuthMethod string
const (
AuthMethodConnectionString AuthMethod = "CONNECTION_STRING"
AuthMethodAccountKey AuthMethod = "ACCOUNT_KEY"
)
type AzureBlobStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
AuthMethod AuthMethod `json:"authMethod" gorm:"not null;type:text;column:auth_method"`
ConnectionString string `json:"connectionString" gorm:"type:text;column:connection_string"`
AccountName string `json:"accountName" gorm:"type:text;column:account_name"`
AccountKey string `json:"accountKey" gorm:"type:text;column:account_key"`
ContainerName string `json:"containerName" gorm:"not null;type:text;column:container_name"`
Endpoint string `json:"endpoint" gorm:"type:text;column:endpoint"`
Prefix string `json:"prefix" gorm:"type:text;column:prefix"`
}
func (s *AzureBlobStorage) TableName() string {
return "azure_blob_storages"
}
func (s *AzureBlobStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
_, err = client.UploadStream(
context.TODO(),
s.ContainerName,
blobName,
file,
nil,
)
if err != nil {
return fmt.Errorf("failed to upload blob to Azure: %w", err)
}
return nil
}
func (s *AzureBlobStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
blobName := s.buildBlobName(fileID.String())
response, err := client.DownloadStream(
context.TODO(),
s.ContainerName,
blobName,
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to download blob from Azure: %w", err)
}
return response.Body, nil
}
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
_, err = client.DeleteBlob(
context.TODO(),
s.ContainerName,
blobName,
nil,
)
if err != nil {
var respErr *azcore.ResponseError
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed to delete blob from Azure: %w", err)
}
return nil
}
func (s *AzureBlobStorage) Validate(encryptor encryption.FieldEncryptor) error {
if s.ContainerName == "" {
return errors.New("container name is required")
}
switch s.AuthMethod {
case AuthMethodConnectionString:
if s.ConnectionString == "" {
return errors.New(
"connection string is required when using CONNECTION_STRING auth method",
)
}
case AuthMethodAccountKey:
if s.AccountName == "" {
return errors.New("account name is required when using ACCOUNT_KEY auth method")
}
if s.AccountKey == "" {
return errors.New("account key is required when using ACCOUNT_KEY auth method")
}
default:
return fmt.Errorf("invalid auth method: %s", s.AuthMethod)
}
return nil
}
func (s *AzureBlobStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
containerClient := client.ServiceClient().NewContainerClient(s.ContainerName)
_, err = containerClient.GetProperties(ctx, nil)
if err != nil {
var respErr *azcore.ResponseError
if errors.As(err, &respErr) {
if respErr.StatusCode == 404 {
return fmt.Errorf("container '%s' does not exist", s.ContainerName)
}
}
if errors.Is(err, context.DeadlineExceeded) {
return errors.New("failed to connect to Azure Blob Storage. Please check params")
}
return fmt.Errorf("failed to connect to Azure Blob Storage: %w", err)
}
testBlobName := s.buildBlobName(uuid.New().String() + "-test")
testData := []byte("test connection")
_, err = client.UploadStream(
ctx,
s.ContainerName,
testBlobName,
bytes.NewReader(testData),
nil,
)
if err != nil {
return fmt.Errorf("failed to upload test blob to Azure: %w", err)
}
_, err = client.DeleteBlob(
ctx,
s.ContainerName,
testBlobName,
nil,
)
if err != nil {
return fmt.Errorf("failed to delete test blob from Azure: %w", err)
}
return nil
}
func (s *AzureBlobStorage) HideSensitiveData() {
s.ConnectionString = ""
s.AccountKey = ""
}
func (s *AzureBlobStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ConnectionString != "" {
s.ConnectionString, err = encryptor.Encrypt(s.StorageID, s.ConnectionString)
if err != nil {
return fmt.Errorf("failed to encrypt Azure connection string: %w", err)
}
}
if s.AccountKey != "" {
s.AccountKey, err = encryptor.Encrypt(s.StorageID, s.AccountKey)
if err != nil {
return fmt.Errorf("failed to encrypt Azure account key: %w", err)
}
}
return nil
}
func (s *AzureBlobStorage) Update(incoming *AzureBlobStorage) {
s.AuthMethod = incoming.AuthMethod
s.ContainerName = incoming.ContainerName
s.Endpoint = incoming.Endpoint
if incoming.ConnectionString != "" {
s.ConnectionString = incoming.ConnectionString
}
if incoming.AccountName != "" {
s.AccountName = incoming.AccountName
}
if incoming.AccountKey != "" {
s.AccountKey = incoming.AccountKey
}
}
func (s *AzureBlobStorage) buildBlobName(fileName string) string {
if s.Prefix == "" {
return fileName
}
prefix := s.Prefix
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
}
return prefix + fileName
}
func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azblob.Client, error) {
var client *azblob.Client
var err error
switch s.AuthMethod {
case AuthMethodConnectionString:
connectionString, decryptErr := encryptor.Decrypt(s.StorageID, s.ConnectionString)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure connection string: %w", decryptErr)
}
client, err = azblob.NewClientFromConnectionString(connectionString, nil)
if err != nil {
return nil, fmt.Errorf(
"failed to create Azure Blob client from connection string: %w",
err,
)
}
case AuthMethodAccountKey:
accountKey, decryptErr := encryptor.Decrypt(s.StorageID, s.AccountKey)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure account key: %w", decryptErr)
}
accountURL := s.buildAccountURL()
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, accountKey)
if credErr != nil {
return nil, fmt.Errorf("failed to create Azure shared key credential: %w", credErr)
}
client, err = azblob.NewClientWithSharedKeyCredential(accountURL, credential, nil)
if err != nil {
return nil, fmt.Errorf("failed to create Azure Blob client with shared key: %w", err)
}
default:
return nil, fmt.Errorf("unsupported auth method: %s", s.AuthMethod)
}
return client, nil
}
func (s *AzureBlobStorage) buildAccountURL() string {
if s.Endpoint != "" {
endpoint := s.Endpoint
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
endpoint = "https://" + endpoint
}
return endpoint
}
return fmt.Sprintf("https://%s.blob.core.windows.net/", s.AccountName)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -30,11 +31,12 @@ func (s *GoogleDriveStorage) TableName() string {
}
func (s *GoogleDriveStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
filename := fileID.String()
@@ -68,9 +70,12 @@ func (s *GoogleDriveStorage) SaveFile(
})
}
func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
func (s *GoogleDriveStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
var result io.ReadCloser
err := s.withRetryOnAuth(func(driveService *drive.Service) error {
err := s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
return fmt.Errorf("failed to find backups folder: %w", err)
@@ -93,8 +98,11 @@ func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return result, err
}
func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
@@ -105,7 +113,7 @@ func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
})
}
func (s *GoogleDriveStorage) Validate() error {
func (s *GoogleDriveStorage) Validate(encryptor encryption.FieldEncryptor) error {
switch {
case s.ClientID == "":
return errors.New("client ID is required")
@@ -115,7 +123,12 @@ func (s *GoogleDriveStorage) Validate() error {
return errors.New("token JSON is required")
}
// Also validate that the token JSON contains a refresh token
// Skip JSON validation if token is already encrypted
if strings.HasPrefix(s.TokenJSON, "enc:") {
return nil
}
// Validate that the token JSON contains a refresh token
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON format: %w", err)
@@ -128,8 +141,8 @@ func (s *GoogleDriveStorage) Validate() error {
return nil
}
func (s *GoogleDriveStorage) TestConnection() error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
func (s *GoogleDriveStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
ctx := context.Background()
testFilename := "test-connection-" + uuid.New().String()
testData := []byte("test")
@@ -196,6 +209,26 @@ func (s *GoogleDriveStorage) HideSensitiveData() {
s.TokenJSON = ""
}
func (s *GoogleDriveStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ClientSecret != "" {
s.ClientSecret, err = encryptor.Encrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive client secret: %w", err)
}
}
if s.TokenJSON != "" {
s.TokenJSON, err = encryptor.Encrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive token JSON: %w", err)
}
}
return nil
}
func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
s.ClientID = incoming.ClientID
@@ -209,8 +242,11 @@ func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
}
// withRetryOnAuth executes the provided function with retry logic for authentication errors
func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) error {
driveService, err := s.getDriveService()
func (s *GoogleDriveStorage) withRetryOnAuth(
encryptor encryption.FieldEncryptor,
fn func(*drive.Service) error,
) error {
driveService, err := s.getDriveService(encryptor)
if err != nil {
return err
}
@@ -220,7 +256,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
// Try to refresh token and retry once
fmt.Printf("Google Drive auth error detected, attempting token refresh: %v\n", err)
if refreshErr := s.refreshToken(); refreshErr != nil {
if refreshErr := s.refreshToken(encryptor); refreshErr != nil {
// If refresh fails, return a more helpful error message
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
strings.Contains(refreshErr.Error(), "refresh token") {
@@ -237,7 +273,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
fmt.Printf("Token refresh successful, retrying operation\n")
// Get new service with refreshed token
driveService, err = s.getDriveService()
driveService, err = s.getDriveService(encryptor)
if err != nil {
return fmt.Errorf("failed to create service after token refresh: %w", err)
}
@@ -268,13 +304,24 @@ func (s *GoogleDriveStorage) isAuthError(err error) bool {
}
// refreshToken refreshes the OAuth2 token and updates the TokenJSON field
func (s *GoogleDriveStorage) refreshToken() error {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) error {
if err := s.Validate(encryptor); err != nil {
return err
}
// Decrypt credentials before use
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON: %w", err)
}
@@ -289,12 +336,12 @@ func (s *GoogleDriveStorage) refreshToken() error {
token.Expiry)
// Debug: Print the full token JSON structure (sensitive data masked)
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(s.TokenJSON))
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(tokenJSON))
ctx := context.Background()
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}
@@ -330,7 +377,7 @@ func (s *GoogleDriveStorage) refreshToken() error {
newToken.RefreshToken = token.RefreshToken
}
// Update the stored token JSON
// Update the stored token JSON (keep as plaintext in memory, encryption happens on save)
newTokenJSON, err := json.Marshal(newToken)
if err != nil {
return fmt.Errorf("failed to marshal refreshed token: %w", err)
@@ -368,13 +415,26 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen]
}
func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) getDriveService(
encryptor encryption.FieldEncryptor,
) (*drive.Service, error) {
if err := s.Validate(encryptor); err != nil {
return nil, err
}
// Decrypt credentials before use
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("invalid token JSON: %w", err)
}
@@ -382,7 +442,7 @@ func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}

View File

@@ -7,6 +7,7 @@ import (
"os"
"path/filepath"
"postgresus-backend/internal/config"
"postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"github.com/google/uuid"
@@ -23,7 +24,12 @@ func (l *LocalStorage) TableName() string {
return "local_storages"
}
func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (l *LocalStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
err := files_utils.EnsureDirectories([]string{
@@ -107,7 +113,10 @@ func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.R
return nil
}
func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
func (l *LocalStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -122,7 +131,7 @@ func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return file, nil
}
func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -136,11 +145,11 @@ func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (l *LocalStorage) Validate() error {
func (l *LocalStorage) Validate(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
testFile := filepath.Join(config.GetEnv().TempFolder, "test_connection")
f, err := os.Create(testFile)
if err != nil {
@@ -160,5 +169,9 @@ func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) HideSensitiveData() {
}
func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) Update(incoming *LocalStorage) {
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net"
"path/filepath"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -31,10 +32,15 @@ func (n *NASStorage) TableName() string {
return "nas_storages"
}
func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (n *NASStorage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
session, err := n.createSession()
session, err := n.createSession(encryptor)
if err != nil {
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to create NAS session: %w", err)
@@ -131,8 +137,11 @@ func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Rea
return nil
}
func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
session, err := n.createSession()
func (n *NASStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
session, err := n.createSession(encryptor)
if err != nil {
return nil, fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -168,8 +177,8 @@ func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
}, nil
}
func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
session, err := n.createSession()
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -202,7 +211,7 @@ func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (n *NASStorage) Validate() error {
func (n *NASStorage) Validate(encryptor encryption.FieldEncryptor) error {
if n.Host == "" {
return errors.New("NAS host is required")
}
@@ -219,12 +228,11 @@ func (n *NASStorage) Validate() error {
return errors.New("NAS port must be between 1 and 65535")
}
// Test the configuration by creating a session
return n.TestConnection()
return nil
}
func (n *NASStorage) TestConnection() error {
session, err := n.createSession()
func (n *NASStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to connect to NAS: %w", err)
}
@@ -255,6 +263,18 @@ func (n *NASStorage) HideSensitiveData() {
n.Password = ""
}
func (n *NASStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.Password != "" {
encrypted, err := encryptor.Encrypt(n.StorageID, n.Password)
if err != nil {
return fmt.Errorf("failed to encrypt NAS password: %w", err)
}
n.Password = encrypted
}
return nil
}
func (n *NASStorage) Update(incoming *NASStorage) {
n.Host = incoming.Host
n.Port = incoming.Port
@@ -269,18 +289,25 @@ func (n *NASStorage) Update(incoming *NASStorage) {
}
}
func (n *NASStorage) createSession() (*smb2.Session, error) {
func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.Session, error) {
// Create connection with timeout
conn, err := n.createConnection()
if err != nil {
return nil, err
}
// Decrypt password before use
password, err := encryptor.Decrypt(n.StorageID, n.Password)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to decrypt NAS password: %w", err)
}
// Create SMB2 dialer
d := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: n.Username,
Password: n.Password,
Password: password,
Domain: n.Domain,
},
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -22,23 +23,33 @@ type S3Storage struct {
S3AccessKey string `json:"s3AccessKey" gorm:"not null;type:text;column:s3_access_key"`
S3SecretKey string `json:"s3SecretKey" gorm:"not null;type:text;column:s3_secret_key"`
S3Endpoint string `json:"s3Endpoint" gorm:"type:text;column:s3_endpoint"`
S3Prefix string `json:"s3Prefix" gorm:"type:text;column:s3_prefix"`
S3UseVirtualHostedStyle bool `json:"s3UseVirtualHostedStyle" gorm:"default:false;column:s3_use_virtual_hosted_style"`
}
func (s *S3Storage) TableName() string {
return "s3_storages"
}
func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
client, err := s.getClient()
func (s *S3Storage) SaveFile(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
objectKey := s.buildObjectKey(fileID.String())
// Upload the file using MinIO client with streaming (size = -1 for unknown size)
_, err = client.PutObject(
context.TODO(),
s.S3Bucket,
fileID.String(),
objectKey,
file,
-1,
minio.PutObjectOptions{},
@@ -50,16 +61,21 @@ func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Read
return nil
}
func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
client, err := s.getClient()
func (s *S3Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
objectKey := s.buildObjectKey(fileID.String())
object, err := client.GetObject(
context.TODO(),
s.S3Bucket,
fileID.String(),
objectKey,
minio.GetObjectOptions{},
)
if err != nil {
@@ -84,17 +100,19 @@ func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return object, nil
}
func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
client, err := s.getClient()
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
objectKey := s.buildObjectKey(fileID.String())
// Delete the object using MinIO client
err = client.RemoveObject(
context.TODO(),
s.S3Bucket,
fileID.String(),
objectKey,
minio.RemoveObjectOptions{},
)
if err != nil {
@@ -104,7 +122,7 @@ func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (s *S3Storage) Validate() error {
func (s *S3Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.S3Bucket == "" {
return errors.New("S3 bucket is required")
}
@@ -115,17 +133,11 @@ func (s *S3Storage) Validate() error {
return errors.New("S3 secret key is required")
}
// Try to create a client to validate the configuration
_, err := s.getClient()
if err != nil {
return fmt.Errorf("invalid S3 configuration: %w", err)
}
return nil
}
func (s *S3Storage) TestConnection() error {
client, err := s.getClient()
func (s *S3Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -150,6 +162,7 @@ func (s *S3Storage) TestConnection() error {
// Test write and delete permissions by uploading and removing a small test file
testFileID := uuid.New().String() + "-test"
testObjectKey := s.buildObjectKey(testFileID)
testData := []byte("test connection")
testReader := bytes.NewReader(testData)
@@ -157,7 +170,7 @@ func (s *S3Storage) TestConnection() error {
_, err = client.PutObject(
ctx,
s.S3Bucket,
testFileID,
testObjectKey,
testReader,
int64(len(testData)),
minio.PutObjectOptions{},
@@ -170,7 +183,7 @@ func (s *S3Storage) TestConnection() error {
err = client.RemoveObject(
ctx,
s.S3Bucket,
testFileID,
testObjectKey,
minio.RemoveObjectOptions{},
)
if err != nil {
@@ -185,10 +198,31 @@ func (s *S3Storage) HideSensitiveData() {
s.S3SecretKey = ""
}
func (s *S3Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.S3AccessKey != "" {
s.S3AccessKey, err = encryptor.Encrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 access key: %w", err)
}
}
if s.S3SecretKey != "" {
s.S3SecretKey, err = encryptor.Encrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 secret key: %w", err)
}
}
return nil
}
func (s *S3Storage) Update(incoming *S3Storage) {
s.S3Bucket = incoming.S3Bucket
s.S3Region = incoming.S3Region
s.S3Endpoint = incoming.S3Endpoint
s.S3UseVirtualHostedStyle = incoming.S3UseVirtualHostedStyle
if incoming.S3AccessKey != "" {
s.S3AccessKey = incoming.S3AccessKey
@@ -197,9 +231,27 @@ func (s *S3Storage) Update(incoming *S3Storage) {
if incoming.S3SecretKey != "" {
s.S3SecretKey = incoming.S3SecretKey
}
// we do not allow to change the prefix after creation,
// otherwise we will have to migrate all the data to the new prefix
}
func (s *S3Storage) getClient() (*minio.Client, error) {
func (s *S3Storage) buildObjectKey(fileName string) string {
if s.S3Prefix == "" {
return fileName
}
prefix := s.S3Prefix
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
}
return prefix + fileName
}
func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Client, error) {
endpoint := s.S3Endpoint
useSSL := true
@@ -215,11 +267,29 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", s.S3Region)
}
// Decrypt credentials before use
accessKey, err := encryptor.Decrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
}
secretKey, err := encryptor.Decrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
}
// Configure bucket lookup strategy
bucketLookup := minio.BucketLookupAuto
if s.S3UseVirtualHostedStyle {
bucketLookup = minio.BucketLookupDNS
}
// Initialize the MinIO client
minioClient, err := minio.New(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(s.S3AccessKey, s.S3SecretKey, ""),
Secure: useSSL,
Region: s.S3Region,
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
Secure: useSSL,
Region: s.S3Region,
BucketLookup: bucketLookup,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize MinIO client: %w", err)

View File

@@ -30,17 +30,21 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
if storage.NASStorage != nil {
storage.NASStorage.StorageID = storage.ID
}
case StorageTypeAzureBlob:
if storage.AzureBlobStorage != nil {
storage.AzureBlobStorage.StorageID = storage.ID
}
}
if storage.ID == uuid.Nil {
if err := tx.Create(storage).
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage").
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage").
Error; err != nil {
return err
}
} else {
if err := tx.Save(storage).
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage").
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage").
Error; err != nil {
return err
}
@@ -75,6 +79,13 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
return err
}
}
case StorageTypeAzureBlob:
if storage.AzureBlobStorage != nil {
storage.AzureBlobStorage.StorageID = storage.ID // Ensure ID is set
if err := tx.Save(storage.AzureBlobStorage).Error; err != nil {
return err
}
}
}
return nil
@@ -96,6 +107,7 @@ func (r *StorageRepository) FindByID(id uuid.UUID) (*Storage, error) {
Preload("S3Storage").
Preload("GoogleDriveStorage").
Preload("NASStorage").
Preload("AzureBlobStorage").
Where("id = ?", id).
First(&s).Error; err != nil {
return nil, err
@@ -113,6 +125,7 @@ func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage
Preload("S3Storage").
Preload("GoogleDriveStorage").
Preload("NASStorage").
Preload("AzureBlobStorage").
Where("workspace_id = ?", workspaceID).
Order("name ASC").
Find(&storages).Error; err != nil {
@@ -150,6 +163,12 @@ func (r *StorageRepository) Delete(s *Storage) error {
return err
}
}
case StorageTypeAzureBlob:
if s.AzureBlobStorage != nil {
if err := tx.Delete(s.AzureBlobStorage).Error; err != nil {
return err
}
}
}
// Delete the main storage

View File

@@ -7,6 +7,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -15,6 +16,7 @@ type StorageService struct {
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *StorageService) SaveStorage(
@@ -44,7 +46,11 @@ func (s *StorageService) SaveStorage(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -61,7 +67,11 @@ func (s *StorageService) SaveStorage(
} else {
storage.WorkspaceID = workspaceID
if err := storage.Validate(); err != nil {
if err := storage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := storage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -174,7 +184,7 @@ func (s *StorageService) TestStorageConnection(
return errors.New("insufficient permissions to test storage in this workspace")
}
err = storage.TestConnection()
err = storage.TestConnection(s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
storage.LastSaveError = &lastSaveError
@@ -207,7 +217,7 @@ func (s *StorageService) TestStorageConnectionDirect(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -216,7 +226,7 @@ func (s *StorageService) TestStorageConnectionDirect(
usingStorage = storage
}
return usingStorage.TestConnection()
return usingStorage.TestConnection(s.fieldEncryptor)
}
func (s *StorageService) GetStorageByID(

View File

@@ -1,31 +1,36 @@
package tests
import (
"context"
"encoding/json"
"fmt"
"net/http"
"os"
"path/filepath"
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups"
usecases_postgresql_backup "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/restores/models"
usecases_postgresql_restore "postgresus-backend/internal/features/restores/usecases/postgresql"
"postgresus-backend/internal/features/storages"
local_storage "postgresus-backend/internal/features/storages/models/local"
"postgresus-backend/internal/util/period"
"postgresus-backend/internal/util/tools"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/restores"
restores_enums "postgresus-backend/internal/features/restores/enums"
restores_models "postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
const createAndFillTableQuery = `
@@ -61,7 +66,6 @@ type TestDataItem struct {
CreatedAt time.Time `db:"created_at"`
}
// Main test functions for each PostgreSQL version
func Test_BackupAndRestorePostgresql_RestoreIsSuccesful(t *testing.T) {
env := config.GetEnv()
cases := []struct {
@@ -79,17 +83,38 @@ func Test_BackupAndRestorePostgresql_RestoreIsSuccesful(t *testing.T) {
}
for _, tc := range cases {
tc := tc // capture loop variable
t.Run(tc.name, func(t *testing.T) {
t.Parallel() // Enable parallel execution
t.Parallel()
testBackupRestoreForVersion(t, tc.version, tc.port)
})
}
}
// Run a test for a specific PostgreSQL version
func Test_BackupAndRestorePostgresqlWithEncryption_RestoreIsSuccessful(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
{"PostgreSQL 18", "18", env.TestPostgres18Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
testBackupRestoreWithEncryptionForVersion(t, tc.version, tc.port)
})
}
}
func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
// Connect to pre-configured PostgreSQL container
container, err := connectToPostgresContainer(pgVersion, port)
assert.NoError(t, err)
defer func() {
@@ -101,55 +126,30 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
_, err = container.DB.Exec(createAndFillTableQuery)
assert.NoError(t, err)
// Prepare data for backup
backupID := uuid.New()
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
pgVersionEnum := tools.GetPostgresqlVersionEnum(pgVersion)
backupDb := &databases.Database{
ID: uuid.New(),
Type: databases.DatabaseTypePostgres,
Name: "Test Database",
Postgresql: &pgtypes.PostgresqlDatabase{
Version: pgVersionEnum,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
},
}
storageID := uuid.New()
backupConfig := &backups_config.BackupConfig{
DatabaseID: backupDb.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodDay,
BackupInterval: &intervals.Interval{Interval: intervals.IntervalDaily},
StorageID: &storageID,
CpuCount: 1,
}
storage := &storages.Storage{
WorkspaceID: uuid.New(),
Type: storages.StorageTypeLocal,
Name: "Test Storage",
LocalStorage: &local_storage.LocalStorage{},
}
// Make backup
progressTracker := func(completedMBs float64) {}
err = usecases_postgresql_backup.GetCreatePostgresqlBackupUsecase().Execute(
context.Background(),
backupID,
backupConfig,
backupDb,
storage,
progressTracker,
database := createDatabaseViaAPI(
t, router, "Test Database", workspace.ID,
pgVersionEnum, container.Host, container.Port,
container.Username, container.Password, container.Database,
user.Token,
)
assert.NoError(t, err)
// Create new database
enableBackupsViaAPI(
t, router, database.ID, storage.ID,
backups_config.BackupEncryptionNone, user.Token,
)
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
newDBName := "restoreddb"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
@@ -157,43 +157,22 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
// Connect to the new database
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
assert.NoError(t, err)
defer newDB.Close()
// Setup data for restore
completedBackup := &backups.Backup{
ID: backupID,
DatabaseID: backupDb.ID,
StorageID: storage.ID,
Status: backups.BackupStatusCompleted,
CreatedAt: time.Now().UTC(),
}
createRestoreViaAPI(
t, router, backup.ID, pgVersionEnum,
container.Host, container.Port,
container.Username, container.Password, newDBName,
user.Token,
)
restoreID := uuid.New()
restore := models.Restore{
ID: restoreID,
Backup: completedBackup,
Postgresql: &pgtypes.PostgresqlDatabase{
Version: pgVersionEnum,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &newDBName,
IsHttps: false,
},
}
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
// Restore the backup
restoreBackupUC := usecases_postgresql_restore.GetRestorePostgresqlBackupUsecase()
err = restoreBackupUC.Execute(backupDb, backupConfig, restore, completedBackup, storage)
assert.NoError(t, err)
// Verify restored table exists
var tableExists bool
err = newDB.Get(
&tableExists,
@@ -202,17 +181,329 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
// Verify data integrity
verifyDataIntegrity(t, container.DB, newDB)
// Clean up the backup file after the test
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backupID.String()))
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
t.Logf("Warning: Failed to delete backup file: %v", err)
}
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+user.Token,
http.StatusNoContent,
)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func testBackupRestoreWithEncryptionForVersion(t *testing.T, pgVersion string, port string) {
container, err := connectToPostgresContainer(pgVersion, port)
assert.NoError(t, err)
defer func() {
if container.DB != nil {
container.DB.Close()
}
}()
_, err = container.DB.Exec(createAndFillTableQuery)
assert.NoError(t, err)
router := createTestRouter()
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
pgVersionEnum := tools.GetPostgresqlVersionEnum(pgVersion)
database := createDatabaseViaAPI(
t, router, "Test Database", workspace.ID,
pgVersionEnum, container.Host, container.Port,
container.Username, container.Password, container.Database,
user.Token,
)
enableBackupsViaAPI(
t, router, database.ID, storage.ID,
backups_config.BackupEncryptionEncrypted, user.Token,
)
createBackupViaAPI(t, router, database.ID, user.Token)
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
newDBName := "restoreddb_encrypted"
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
assert.NoError(t, err)
newDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, newDBName)
newDB, err := sqlx.Connect("postgres", newDSN)
assert.NoError(t, err)
defer newDB.Close()
createRestoreViaAPI(
t, router, backup.ID, pgVersionEnum,
container.Host, container.Port,
container.Username, container.Password, newDBName,
user.Token,
)
restore := waitForRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
var tableExists bool
err = newDB.Get(
&tableExists,
"SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'test_data')",
)
assert.NoError(t, err)
assert.True(t, tableExists, "Table 'test_data' should exist in restored database")
verifyDataIntegrity(t, container.DB, newDB)
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
if err != nil {
t.Logf("Warning: Failed to delete backup file: %v", err)
}
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+user.Token,
http.StatusNoContent,
)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
restores.GetRestoreController(),
)
return router
}
func waitForBackupCompletion(
t *testing.T,
router *gin.Engine,
databaseID uuid.UUID,
token string,
timeout time.Duration,
) *backups.Backup {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
for {
if time.Since(startTime) > timeout {
t.Fatalf("Timeout waiting for backup completion after %v", timeout)
}
var response backups.GetBackupsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backups?database_id=%s&limit=1", databaseID.String()),
"Bearer "+token,
http.StatusOK,
&response,
)
if len(response.Backups) > 0 {
backup := response.Backups[0]
if backup.Status == backups.BackupStatusCompleted {
return backup
}
if backup.Status == backups.BackupStatusFailed {
t.Fatalf("Backup failed: %v", backup.FailMessage)
}
}
time.Sleep(pollInterval)
}
}
func waitForRestoreCompletion(
t *testing.T,
router *gin.Engine,
backupID uuid.UUID,
token string,
timeout time.Duration,
) *restores_models.Restore {
startTime := time.Now()
pollInterval := 500 * time.Millisecond
for {
if time.Since(startTime) > timeout {
t.Fatalf("Timeout waiting for restore completion after %v", timeout)
}
var restores []*restores_models.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backupID.String()),
"Bearer "+token,
http.StatusOK,
&restores,
)
for _, restore := range restores {
if restore.Status == restores_enums.RestoreStatusCompleted {
return restore
}
if restore.Status == restores_enums.RestoreStatusFailed {
t.Fatalf("Restore failed: %v", restore.FailMessage)
}
}
time.Sleep(pollInterval)
}
}
func createDatabaseViaAPI(
t *testing.T,
router *gin.Engine,
name string,
workspaceID uuid.UUID,
pgVersion tools.PostgresqlVersion,
host string,
port int,
username string,
password string,
database string,
token string,
) *databases.Database {
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
Type: databases.DatabaseTypePostgres,
Postgresql: &pgtypes.PostgresqlDatabase{
Version: pgVersion,
Host: host,
Port: port,
Username: username,
Password: password,
Database: &database,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
t.Fatalf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String())
}
var createdDatabase databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &createdDatabase); err != nil {
t.Fatalf("Failed to unmarshal database response: %v", err)
}
return &createdDatabase
}
func enableBackupsViaAPI(
t *testing.T,
router *gin.Engine,
databaseID uuid.UUID,
storageID uuid.UUID,
encryption backups_config.BackupEncryption,
token string,
) {
var backupConfig backups_config.BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/backup-configs/database/%s", databaseID.String()),
"Bearer "+token,
http.StatusOK,
&backupConfig,
)
storage := &storages.Storage{ID: storageID}
backupConfig.IsBackupsEnabled = true
backupConfig.Storage = storage
backupConfig.Encryption = encryption
test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+token,
backupConfig,
http.StatusOK,
)
}
func createBackupViaAPI(
t *testing.T,
router *gin.Engine,
databaseID uuid.UUID,
token string,
) {
request := backups.MakeBackupRequest{DatabaseID: databaseID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+token,
request,
http.StatusOK,
)
}
func createRestoreViaAPI(
t *testing.T,
router *gin.Engine,
backupID uuid.UUID,
pgVersion tools.PostgresqlVersion,
host string,
port int,
username string,
password string,
database string,
token string,
) {
request := restores.RestoreBackupRequest{
PostgresqlDatabase: &pgtypes.PostgresqlDatabase{
Version: pgVersion,
Host: host,
Port: port,
Username: username,
Password: password,
Database: &database,
},
}
test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backupID.String()),
"Bearer "+token,
request,
http.StatusOK,
)
}
// verifyDataIntegrity compares data in the original and restored databases
func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
var originalData []TestDataItem
var restoredData []TestDataItem
@@ -225,7 +516,6 @@ func verifyDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB)
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")
// Only compare data if both slices have elements (to avoid panic)
if len(originalData) > 0 && len(restoredData) > 0 {
for i := range originalData {
assert.Equal(t, originalData[i].ID, restoredData[i].ID, "ID should match")

View File

@@ -1,7 +1,7 @@
package users_models
type SecretKey struct {
Secret string `gorm:"column:secret"`
Secret string `gorm:"column:secret" json:"-"`
}
func (SecretKey) TableName() string {

View File

@@ -0,0 +1,17 @@
package users_repositories
var secretKeyRepository = &SecretKeyRepository{}
var userRepository = &UserRepository{}
var usersSettingsRepository = &UsersSettingsRepository{}
func GetSecretKeyRepository() *SecretKeyRepository {
return secretKeyRepository
}
func GetUserRepository() *UserRepository {
return userRepository
}
func GetUsersSettingsRepository() *UsersSettingsRepository {
return usersSettingsRepository
}

View File

@@ -14,9 +14,7 @@ type SecretKeyRepository struct{}
func (r *SecretKeyRepository) GetSecretKey() (string, error) {
var secretKey user_models.SecretKey
if err := storage.
GetDb().
First(&secretKey).Error; err != nil {
if err := storage.GetDb().First(&secretKey).Error; err != nil {
// create a new secret key if not found
if errors.Is(err, gorm.ErrRecordNotFound) {
newSecretKey := user_models.SecretKey{

View File

@@ -1,25 +1,19 @@
package users_services
import (
user_repositories "postgresus-backend/internal/features/users/repositories"
)
var secretKeyRepository = &user_repositories.SecretKeyRepository{}
var userRepository = &user_repositories.UserRepository{}
var usersSettingsRepository = &user_repositories.UsersSettingsRepository{}
import users_repositories "postgresus-backend/internal/features/users/repositories"
var userService = &UserService{
userRepository,
secretKeyRepository,
users_repositories.GetUserRepository(),
users_repositories.GetSecretKeyRepository(),
settingsService,
nil,
}
var settingsService = &SettingsService{
usersSettingsRepository,
users_repositories.GetUsersSettingsRepository(),
nil,
}
var managementService = &UserManagementService{
userRepository,
users_repositories.GetUserRepository(),
nil,
}

View File

@@ -309,15 +309,6 @@ func (s *UserService) ChangeUserPasswordByEmail(email string, newPassword string
}
func (s *UserService) ChangeUserPassword(userID uuid.UUID, newPassword string) error {
user, err := s.userRepository.GetUserByID(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if !user.HasPassword() {
return errors.New("user has no password set")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)

View File

@@ -0,0 +1,11 @@
package encryption
import users_repositories "postgresus-backend/internal/features/users/repositories"
var fieldEncryptor = &SecretKeyFieldEncryptor{
users_repositories.GetSecretKeyRepository(),
}
func GetFieldEncryptor() FieldEncryptor {
return fieldEncryptor
}

View File

@@ -0,0 +1,15 @@
package encryption
import "github.com/google/uuid"
type FieldEncryptor interface {
// Encrypt encrypts a plaintext string and returns an encrypted string.
// If the string is already encrypted, returns it as-is.
// Empty strings are returned unchanged.
Encrypt(itemID uuid.UUID, plaintext string) (string, error)
// Decrypt decrypts an encrypted string and returns a plaintext string.
// If the string is not encrypted, returns it as-is.
// Empty strings are returned unchanged.
Decrypt(itemID uuid.UUID, ciphertext string) (string, error)
}

View File

@@ -0,0 +1,121 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"strings"
users_repositories "postgresus-backend/internal/features/users/repositories"
"github.com/google/uuid"
)
const encryptedPrefix = "enc:"
type SecretKeyFieldEncryptor struct {
secretKeyRepository *users_repositories.SecretKeyRepository
}
func (e *SecretKeyFieldEncryptor) Encrypt(itemID uuid.UUID, plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
if e.isEncrypted(plaintext) {
return plaintext, nil
}
masterKey, err := e.secretKeyRepository.GetSecretKey()
if err != nil {
return "", fmt.Errorf("failed to get master key: %w", err)
}
block, err := aes.NewCipher([]byte(masterKey)[:32])
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM: %w", err)
}
nonce := e.deriveNonce(itemID, masterKey, gcm.NonceSize())
ciphertext := gcm.Seal(nil, nonce, []byte(plaintext), nil)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
ciphertextBase64 := base64.StdEncoding.EncodeToString(ciphertext)
return fmt.Sprintf("%s%s:%s", encryptedPrefix, nonceBase64, ciphertextBase64), nil
}
func (e *SecretKeyFieldEncryptor) Decrypt(itemID uuid.UUID, ciphertext string) (string, error) {
if ciphertext == "" {
return "", nil
}
if !e.isEncrypted(ciphertext) {
return ciphertext, nil
}
parts := strings.SplitN(ciphertext, ":", 3)
if len(parts) != 3 {
return "", errors.New("invalid encrypted format")
}
nonceBase64 := parts[1]
ciphertextBase64 := parts[2]
nonce, err := base64.StdEncoding.DecodeString(nonceBase64)
if err != nil {
return "", fmt.Errorf("failed to decode nonce: %w", err)
}
encryptedData, err := base64.StdEncoding.DecodeString(ciphertextBase64)
if err != nil {
return "", fmt.Errorf("failed to decode ciphertext: %w", err)
}
masterKey, err := e.secretKeyRepository.GetSecretKey()
if err != nil {
return "", fmt.Errorf("failed to get master key: %w", err)
}
block, err := aes.NewCipher([]byte(masterKey)[:32])
if err != nil {
return "", fmt.Errorf("failed to create cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("failed to create GCM: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, encryptedData, nil)
if err != nil {
return "", fmt.Errorf("failed to decrypt: %w", err)
}
return string(plaintext), nil
}
func (e *SecretKeyFieldEncryptor) isEncrypted(value string) bool {
return strings.HasPrefix(value, encryptedPrefix)
}
func (e *SecretKeyFieldEncryptor) deriveNonce(
itemID uuid.UUID,
masterKey string,
nonceSize int,
) []byte {
h := hmac.New(sha256.New, []byte(masterKey))
h.Write(itemID[:])
hash := h.Sum(nil)
return hash[:nonceSize]
}

View File

@@ -0,0 +1,120 @@
package encryption
import (
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_Encrypt_Decrypt_RoundTrip(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "my-secret-password"
encrypted, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
assert.NotEmpty(t, encrypted)
assert.NotEqual(t, plaintext, encrypted)
assert.Contains(t, encrypted, "enc:")
decrypted, err := encryptor.Decrypt(itemID, encrypted)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}
func Test_Encrypt_EmptyString_ReturnsEmpty(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
encrypted, err := encryptor.Encrypt(itemID, "")
assert.NoError(t, err)
assert.Empty(t, encrypted)
}
func Test_Decrypt_EmptyString_ReturnsEmpty(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
decrypted, err := encryptor.Decrypt(itemID, "")
assert.NoError(t, err)
assert.Empty(t, decrypted)
}
func Test_Decrypt_PlaintextValue_ReturnsAsIs(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "not-encrypted-password"
decrypted, err := encryptor.Decrypt(itemID, plaintext)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted)
}
func Test_Encrypt_DetectsAlreadyEncryptedFormat(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
alreadyEncrypted := "enc:nonce:ciphertext"
result, err := encryptor.Encrypt(itemID, alreadyEncrypted)
assert.NoError(t, err)
assert.Equal(t, alreadyEncrypted, result)
}
func Test_Encrypt_SamePlaintext_DifferentItemIDs_ProducesDifferentCiphertext(t *testing.T) {
encryptor := GetFieldEncryptor()
plaintext := "shared-secret"
itemID1 := uuid.New()
itemID2 := uuid.New()
encrypted1, err := encryptor.Encrypt(itemID1, plaintext)
assert.NoError(t, err)
encrypted2, err := encryptor.Encrypt(itemID2, plaintext)
assert.NoError(t, err)
assert.NotEqual(t, encrypted1, encrypted2)
decrypted1, err := encryptor.Decrypt(itemID1, encrypted1)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted1)
decrypted2, err := encryptor.Decrypt(itemID2, encrypted2)
assert.NoError(t, err)
assert.Equal(t, plaintext, decrypted2)
}
func Test_Encrypt_AlreadyEncrypted_ReturnsAsIs(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "my-password"
encrypted1, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
encrypted2, err := encryptor.Encrypt(itemID, encrypted1)
assert.NoError(t, err)
assert.Equal(t, encrypted1, encrypted2)
}
func Test_Decrypt_MalformedData_ReturnsError(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
_, err := encryptor.Decrypt(itemID, "enc:invalid")
assert.Error(t, err)
_, err = encryptor.Decrypt(itemID, "enc:invalid:invalid-base64")
assert.Error(t, err)
}
func Test_EncryptedFormat_ContainsPrefix(t *testing.T) {
encryptor := GetFieldEncryptor()
itemID := uuid.New()
plaintext := "test-secret"
encrypted, err := encryptor.Encrypt(itemID, plaintext)
assert.NoError(t, err)
assert.Contains(t, encrypted, "enc:")
}

View File

@@ -0,0 +1,17 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE s3_storages
ADD COLUMN s3_prefix TEXT;
ALTER TABLE s3_storages
ADD COLUMN s3_use_virtual_hosted_style BOOLEAN NOT NULL DEFAULT FALSE;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE s3_storages
DROP COLUMN s3_use_virtual_hosted_style;
ALTER TABLE s3_storages
DROP COLUMN s3_prefix;
-- +goose StatementEnd

View File

@@ -0,0 +1,28 @@
-- +goose Up
-- +goose StatementBegin
CREATE TABLE azure_blob_storages (
storage_id UUID PRIMARY KEY,
auth_method TEXT NOT NULL,
connection_string TEXT,
account_name TEXT,
account_key TEXT,
container_name TEXT NOT NULL,
endpoint TEXT,
prefix TEXT
);
ALTER TABLE azure_blob_storages
ADD CONSTRAINT fk_azure_blob_storages_storage
FOREIGN KEY (storage_id)
REFERENCES storages (id)
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
DROP TABLE IF EXISTS azure_blob_storages;
-- +goose StatementEnd

View File

@@ -0,0 +1,25 @@
-- +goose Up
-- +goose StatementBegin
ALTER TABLE backup_configs
ADD COLUMN encryption TEXT NOT NULL DEFAULT 'NONE';
ALTER TABLE backups
ADD COLUMN encryption_salt TEXT,
ADD COLUMN encryption_iv TEXT,
ADD COLUMN encryption TEXT NOT NULL DEFAULT 'NONE';
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
ALTER TABLE backups
DROP COLUMN IF EXISTS encryption,
DROP COLUMN IF EXISTS encryption_iv,
DROP COLUMN IF EXISTS encryption_salt;
ALTER TABLE backup_configs
DROP COLUMN IF EXISTS encryption;
-- +goose StatementEnd

View File

@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Uploaded to: SVG Repo, www.svgrepo.com, Generator: SVG Repo Mixer Tools -->
<svg width="800px" height="800px" viewBox="0 -28.5 256 256" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>path21</title>
<g stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<path d="M118.431947,187.698037 C151.322003,181.887937 178.48731,177.08008 178.799309,177.013916 L179.366585,176.893612 L148.31513,139.958881 C131.236843,119.644776 117.26369,102.945381 117.26369,102.849118 C117.26369,102.666861 149.32694,14.3716012 149.507189,14.057257 C149.567455,13.952452 171.38747,51.62411 202.400338,105.376064 C231.435152,155.699606 255.372949,197.191547 255.595444,197.580359 L255.999996,198.287301 L157.315912,198.274572 L58.6318456,198.261895 L118.431947,187.698073 L118.431947,187.698037 Z M-4.03864498e-06,176.434723 C-4.03864498e-06,176.382721 14.631291,150.983941 32.5139844,119.992969 L65.0279676,63.6457518 L102.919257,31.8473052 C123.759465,14.3581634 140.866667,0.0274832751 140.935253,0.00062917799 C141.003839,-0.0247829554 140.729691,0.665213042 140.326034,1.53468179 C139.922377,2.40415053 121.407304,42.1170321 99.1814268,89.7855264 L58.7707514,176.455514 L29.3853737,176.492355 C13.2234196,176.512639 -4.03864498e-06,176.486664 -4.03864498e-06,176.434703 L-4.03864498e-06,176.434723 Z" fill="#0089D6" fill-rule="nonzero">
</path>
</g>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -4,3 +4,4 @@ export { BackupStatus } from './model/BackupStatus';
export type { Backup } from './model/Backup';
export type { BackupConfig } from './model/BackupConfig';
export { BackupNotificationType } from './model/BackupNotificationType';
export { BackupEncryption } from './model/BackupEncryption';

View File

@@ -1,5 +1,6 @@
import type { Database } from '../../databases/model/Database';
import type { Storage } from '../../storages';
import { BackupEncryption } from './BackupEncryption';
import { BackupStatus } from './BackupStatus';
export interface Backup {
@@ -15,5 +16,7 @@ export interface Backup {
backupDurationMs: number;
encryption: BackupEncryption;
createdAt: Date;
}

View File

@@ -1,6 +1,7 @@
import type { Period } from '../../databases/model/Period';
import type { Interval } from '../../intervals';
import type { Storage } from '../../storages';
import { BackupEncryption } from './BackupEncryption';
import type { BackupNotificationType } from './BackupNotificationType';
export interface BackupConfig {
@@ -14,4 +15,5 @@ export interface BackupConfig {
cpuCount: number;
isRetryIfFailed: boolean;
maxFailedTriesCount: number;
encryption: BackupEncryption;
}

View File

@@ -0,0 +1,4 @@
export enum BackupEncryption {
NONE = 'NONE',
ENCRYPTED = 'ENCRYPTED',
}

View File

@@ -7,3 +7,4 @@ export { type NASStorage } from './models/NASStorage';
export { getStorageLogoFromType } from './models/getStorageLogoFromType';
export { getStorageNameFromType } from './models/getStorageNameFromType';
export { type GoogleDriveStorage } from './models/GoogleDriveStorage';
export { type AzureBlobStorage } from './models/AzureBlobStorage';

View File

@@ -0,0 +1,9 @@
export interface AzureBlobStorage {
authMethod: 'CONNECTION_STRING' | 'ACCOUNT_KEY';
connectionString: string;
accountName: string;
accountKey: string;
containerName: string;
endpoint?: string;
prefix?: string;
}

View File

@@ -4,4 +4,6 @@ export interface S3Storage {
s3AccessKey: string;
s3SecretKey: string;
s3Endpoint?: string;
s3Prefix?: string;
s3UseVirtualHostedStyle?: boolean;
}

View File

@@ -1,3 +1,4 @@
import type { AzureBlobStorage } from './AzureBlobStorage';
import type { GoogleDriveStorage } from './GoogleDriveStorage';
import type { LocalStorage } from './LocalStorage';
import type { NASStorage } from './NASStorage';
@@ -16,4 +17,5 @@ export interface Storage {
s3Storage?: S3Storage;
googleDriveStorage?: GoogleDriveStorage;
nasStorage?: NASStorage;
azureBlobStorage?: AzureBlobStorage;
}

View File

@@ -3,4 +3,5 @@ export enum StorageType {
S3 = 'S3',
GOOGLE_DRIVE = 'GOOGLE_DRIVE',
NAS = 'NAS',
AZURE_BLOB = 'AZURE_BLOB',
}

View File

@@ -10,6 +10,8 @@ export const getStorageLogoFromType = (type: StorageType) => {
return '/icons/storages/google-drive.svg';
case StorageType.NAS:
return '/icons/storages/nas.svg';
case StorageType.AZURE_BLOB:
return '/icons/storages/azure.svg';
default:
return '';
}

View File

@@ -10,6 +10,8 @@ export const getStorageNameFromType = (type: StorageType) => {
return 'Google Drive';
case StorageType.NAS:
return 'NAS';
case StorageType.AZURE_BLOB:
return 'Azure Blob Storage';
default:
return '';
}

View File

@@ -6,6 +6,7 @@ import {
DownloadOutlined,
ExclamationCircleOutlined,
InfoCircleOutlined,
LockOutlined,
SyncOutlined,
} from '@ant-design/icons';
import { Button, Modal, Spin, Table, Tooltip } from 'antd';
@@ -16,6 +17,7 @@ import { useEffect, useRef, useState } from 'react';
import {
type Backup,
type BackupConfig,
BackupEncryption,
BackupStatus,
backupConfigApi,
backupsApi,
@@ -318,6 +320,12 @@ export const BackupsComponent = ({ database, isCanManageDBs, scrollContainerRef
<div className="flex items-center text-green-600">
<CheckCircleOutlined className="mr-2" style={{ fontSize: 16 }} />
<div>Successful</div>
{record.encryption === BackupEncryption.ENCRYPTED && (
<Tooltip title="Encrypted">
<LockOutlined className="ml-1" style={{ fontSize: 14 }} />
</Tooltip>
)}
</div>
);
}

View File

@@ -13,7 +13,7 @@ import {
import dayjs, { Dayjs } from 'dayjs';
import { useEffect, useMemo, useState } from 'react';
import { type BackupConfig, backupConfigApi } from '../../../entity/backups';
import { type BackupConfig, BackupEncryption, backupConfigApi } from '../../../entity/backups';
import { BackupNotificationType } from '../../../entity/backups/model/BackupNotificationType';
import type { Database } from '../../../entity/databases';
import { Period } from '../../../entity/databases/model/Period';
@@ -153,6 +153,7 @@ export const EditBackupConfigComponent = ({
sendNotificationsOn: [],
isRetryIfFailed: true,
maxFailedTriesCount: 3,
encryption: BackupEncryption.ENCRYPTED,
});
}
loadStorages();
@@ -195,6 +196,7 @@ export const EditBackupConfigComponent = ({
(Boolean(backupConfig.storePeriod) &&
Boolean(backupConfig.storage?.id) &&
Boolean(backupConfig.cpuCount) &&
Boolean(backupConfig.encryption) &&
Boolean(backupInterval?.interval) &&
(!backupInterval ||
((backupInterval.interval !== IntervalType.WEEKLY || displayedWeekday) &&
@@ -418,6 +420,27 @@ export const EditBackupConfigComponent = ({
)}
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Encryption</div>
<Select
value={backupConfig.encryption}
onChange={(v) => updateBackupConfig({ encryption: v })}
size="small"
className="max-w-[200px] grow"
options={[
{ label: 'None', value: BackupEncryption.NONE },
{ label: 'Encrypt backup files', value: BackupEncryption.ENCRYPTED },
]}
/>
<Tooltip
className="cursor-pointer"
title="If backup is encrypted, backup files in your storage (S3, local, etc.) cannot be used directly. You can restore backups through Postgresus or download them unencrypted via the 'Download' button."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
{backupConfig.isBackupsEnabled && (
<>
<div className="mt-4 mb-1 flex w-full items-start">

View File

@@ -1,8 +1,10 @@
import { InfoCircleOutlined } from '@ant-design/icons';
import { Tooltip } from 'antd';
import dayjs from 'dayjs';
import { useMemo } from 'react';
import { useEffect, useState } from 'react';
import { type BackupConfig, backupConfigApi } from '../../../entity/backups';
import { type BackupConfig, BackupEncryption, backupConfigApi } from '../../../entity/backups';
import { BackupNotificationType } from '../../../entity/backups/model/BackupNotificationType';
import type { Database } from '../../../entity/databases';
import { Period } from '../../../entity/databases/model/Period';
@@ -167,6 +169,18 @@ export const ShowBackupConfigComponent = ({ database }: Props) => {
</div>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Encryption</div>
<div>{backupConfig.encryption === BackupEncryption.ENCRYPTED ? 'Enabled' : 'None'}</div>
<Tooltip
className="cursor-pointer"
title="If backup is encrypted, backup files in your storage (S3, local, etc.) cannot be used directly. You can restore backups through Postgresus or download them unencrypted via the 'Download' button."
>
<InfoCircleOutlined className="ml-2" style={{ color: 'gray' }} />
</Tooltip>
</div>
<div className="mb-1 flex w-full items-center">
<div className="min-w-[150px]">Notifications</div>
<div>

View File

@@ -16,7 +16,7 @@ import { EditDatabaseSpecificDataComponent } from './edit/EditDatabaseSpecificDa
interface Props {
workspaceId: string;
onCreated: () => void;
onCreated: (databaseId: string) => void;
onClose: () => void;
}
@@ -58,7 +58,7 @@ export const CreateDatabaseComponent = ({ workspaceId, onCreated, onClose }: Pro
await backupsApi.makeBackup(createdDatabase.id);
}
onCreated();
onCreated(createdDatabase.id);
onClose();
} catch (error) {
alert(error);

View File

@@ -14,6 +14,8 @@ interface Props {
isCanManageDBs: boolean;
}
const SELECTED_DATABASE_STORAGE_KEY = 'selectedDatabaseId';
export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }: Props) => {
const [isLoading, setIsLoading] = useState(true);
const [databases, setDatabases] = useState<Database[]>([]);
@@ -22,7 +24,16 @@ export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }:
const [isShowAddDatabase, setIsShowAddDatabase] = useState(false);
const [selectedDatabaseId, setSelectedDatabaseId] = useState<string | undefined>(undefined);
const loadDatabases = (isSilent = false) => {
const updateSelectedDatabaseId = (databaseId: string | undefined) => {
setSelectedDatabaseId(databaseId);
if (databaseId) {
localStorage.setItem(`${SELECTED_DATABASE_STORAGE_KEY}_${workspace.id}`, databaseId);
} else {
localStorage.removeItem(`${SELECTED_DATABASE_STORAGE_KEY}_${workspace.id}`);
}
};
const loadDatabases = (isSilent = false, selectDatabaseId?: string) => {
if (!isSilent) {
setIsLoading(true);
}
@@ -31,8 +42,17 @@ export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }:
.getDatabases(workspace.id)
.then((databases) => {
setDatabases(databases);
if (!selectedDatabaseId && !isSilent) {
setSelectedDatabaseId(databases[0]?.id);
if (selectDatabaseId) {
updateSelectedDatabaseId(selectDatabaseId);
} else if (!selectedDatabaseId && !isSilent) {
const savedDatabaseId = localStorage.getItem(
`${SELECTED_DATABASE_STORAGE_KEY}_${workspace.id}`,
);
const databaseToSelect =
savedDatabaseId && databases.some((db) => db.id === savedDatabaseId)
? savedDatabaseId
: databases[0]?.id;
updateSelectedDatabaseId(databaseToSelect);
}
})
.catch((e) => alert(e.message))
@@ -95,7 +115,7 @@ export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }:
key={database.id}
database={database}
selectedDatabaseId={selectedDatabaseId}
setSelectedDatabaseId={setSelectedDatabaseId}
setSelectedDatabaseId={updateSelectedDatabaseId}
/>
))
: searchQuery && (
@@ -119,10 +139,11 @@ export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }:
loadDatabases();
}}
onDatabaseDeleted={() => {
loadDatabases();
setSelectedDatabaseId(
databases.filter((database) => database.id !== selectedDatabaseId)[0]?.id,
const remainingDatabases = databases.filter(
(database) => database.id !== selectedDatabaseId,
);
updateSelectedDatabaseId(remainingDatabases[0]?.id);
loadDatabases();
}}
isCanManageDBs={isCanManageDBs}
/>
@@ -141,8 +162,8 @@ export const DatabasesComponent = ({ contentHeight, workspace, isCanManageDBs }:
<CreateDatabaseComponent
workspaceId={workspace.id}
onCreated={() => {
loadDatabases();
onCreated={(databaseId) => {
loadDatabases(false, databaseId);
setIsShowAddDatabase(false);
}}
onClose={() => setIsShowAddDatabase(false)}

View File

@@ -255,7 +255,10 @@ export function EditNotifierComponent({
<EditTelegramNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
@@ -263,7 +266,10 @@ export function EditNotifierComponent({
<EditEmailNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
@@ -271,7 +277,10 @@ export function EditNotifierComponent({
<EditWebhookNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
@@ -279,7 +288,10 @@ export function EditNotifierComponent({
<EditSlackNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
@@ -287,14 +299,20 @@ export function EditNotifierComponent({
<EditDiscordNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
{notifier?.notifierType === NotifierType.TEAMS && (
<EditTeamsNotifierComponent
notifier={notifier}
setNotifier={setNotifier}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestNotificationSuccess(false);
}}
/>
)}
</div>

View File

@@ -5,10 +5,10 @@ import type { Notifier } from '../../../../../entity/notifiers';
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditDiscordNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditDiscordNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
return (
<>
<div className="flex">
@@ -26,7 +26,7 @@ export function EditDiscordNotifierComponent({ notifier, setNotifier, setIsUnsav
channelWebhookUrl: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"

View File

@@ -6,10 +6,10 @@ import type { Notifier } from '../../../../../entity/notifiers';
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditEmailNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
return (
<>
<div className="mb-1 flex items-center">
@@ -26,7 +26,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
targetEmail: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
@@ -52,7 +52,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
smtpHost: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
@@ -75,7 +75,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
smtpPort: Number(e.target.value),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
@@ -97,7 +97,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
smtpUser: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
@@ -120,7 +120,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
smtpPassword: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"
@@ -142,7 +142,7 @@ export function EditEmailNotifierComponent({ notifier, setNotifier, setIsUnsaved
from: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full max-w-[250px]"

View File

@@ -5,10 +5,10 @@ import type { Notifier } from '../../../../../entity/notifiers';
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditSlackNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditSlackNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
return (
<>
<div className="mb-1 ml-[130px] max-w-[200px]" style={{ lineHeight: 1 }}>
@@ -38,7 +38,7 @@ export function EditSlackNotifierComponent({ notifier, setNotifier, setIsUnsaved
botToken: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"
@@ -63,7 +63,7 @@ export function EditSlackNotifierComponent({ notifier, setNotifier, setIsUnsaved
targetChatId: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"

View File

@@ -7,10 +7,10 @@ import type { Notifier } from '../../../../../entity/notifiers';
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditTeamsNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditTeamsNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
const value = notifier?.teamsNotifier?.powerAutomateUrl || '';
const onChange = (e: React.ChangeEvent<HTMLInputElement>) => {
@@ -22,7 +22,7 @@ export function EditTeamsNotifierComponent({ notifier, setNotifier, setIsUnsaved
powerAutomateUrl,
},
});
setIsUnsaved(true);
setUnsaved();
};
return (

View File

@@ -7,10 +7,10 @@ import type { Notifier } from '../../../../../entity/notifiers';
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditTelegramNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditTelegramNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
const [isShowHowToGetChatId, setIsShowHowToGetChatId] = useState(false);
useEffect(() => {
@@ -42,7 +42,7 @@ export function EditTelegramNotifierComponent({ notifier, setNotifier, setIsUnsa
botToken: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"
@@ -78,7 +78,7 @@ export function EditTelegramNotifierComponent({ notifier, setNotifier, setIsUnsa
targetChatId: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"
@@ -137,7 +137,7 @@ export function EditTelegramNotifierComponent({ notifier, setNotifier, setIsUnsa
threadId: checked ? notifier.telegramNotifier.threadId : undefined,
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
/>
@@ -171,7 +171,7 @@ export function EditTelegramNotifierComponent({ notifier, setNotifier, setIsUnsa
threadId: !isNaN(threadId!) ? threadId : undefined,
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"

View File

@@ -7,10 +7,10 @@ import { WebhookMethod } from '../../../../../entity/notifiers/models/webhook/We
interface Props {
notifier: Notifier;
setNotifier: (notifier: Notifier) => void;
setIsUnsaved: (isUnsaved: boolean) => void;
setUnsaved: () => void;
}
export function EditWebhookNotifierComponent({ notifier, setNotifier, setIsUnsaved }: Props) {
export function EditWebhookNotifierComponent({ notifier, setNotifier, setUnsaved }: Props) {
return (
<>
<div className="flex items-center">
@@ -27,7 +27,7 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setIsUnsav
webhookUrl: e.target.value.trim(),
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"
@@ -50,7 +50,7 @@ export function EditWebhookNotifierComponent({ notifier, setNotifier, setIsUnsav
webhookMethod: value,
},
});
setIsUnsaved(true);
setUnsaved();
}}
size="small"
className="w-full"

View File

@@ -8,6 +8,7 @@ import {
storageApi,
} from '../../../../entity/storages';
import { ToastHelper } from '../../../../shared/toast';
import { EditAzureBlobStorageComponent } from './storages/EditAzureBlobStorageComponent';
import { EditGoogleDriveStorageComponent } from './storages/EditGoogleDriveStorageComponent';
import { EditNASStorageComponent } from './storages/EditNASStorageComponent';
import { EditS3StorageComponent } from './storages/EditS3StorageComponent';
@@ -80,6 +81,7 @@ export function EditStorageComponent({
storage.localStorage = undefined;
storage.s3Storage = undefined;
storage.googleDriveStorage = undefined;
storage.azureBlobStorage = undefined;
if (type === StorageType.LOCAL) {
storage.localStorage = {};
@@ -115,6 +117,18 @@ export function EditStorageComponent({
};
}
if (type === StorageType.AZURE_BLOB) {
storage.azureBlobStorage = {
authMethod: 'ACCOUNT_KEY',
connectionString: '',
accountName: '',
accountKey: '',
containerName: '',
endpoint: '',
prefix: '',
};
}
setStorage(
JSON.parse(
JSON.stringify({
@@ -197,6 +211,26 @@ export function EditStorageComponent({
);
}
if (storage.type === StorageType.AZURE_BLOB) {
if (storage.id) {
return storage.azureBlobStorage?.containerName;
}
const isContainerNameFilled = storage.azureBlobStorage?.containerName;
if (storage.azureBlobStorage?.authMethod === 'CONNECTION_STRING') {
return isContainerNameFilled && storage.azureBlobStorage?.connectionString;
}
if (storage.azureBlobStorage?.authMethod === 'ACCOUNT_KEY') {
return (
isContainerNameFilled &&
storage.azureBlobStorage?.accountName &&
storage.azureBlobStorage?.accountKey
);
}
}
return false;
};
@@ -231,6 +265,7 @@ export function EditStorageComponent({
{ label: 'S3', value: StorageType.S3 },
{ label: 'Google Drive', value: StorageType.GOOGLE_DRIVE },
{ label: 'NAS', value: StorageType.NAS },
{ label: 'Azure Blob Storage', value: StorageType.AZURE_BLOB },
]}
onChange={(value) => {
setStorageType(value);
@@ -250,7 +285,10 @@ export function EditStorageComponent({
<EditS3StorageComponent
storage={storage}
setStorage={setStorage}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestConnectionSuccess(false);
}}
/>
)}
@@ -258,7 +296,10 @@ export function EditStorageComponent({
<EditGoogleDriveStorageComponent
storage={storage}
setStorage={setStorage}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestConnectionSuccess(false);
}}
/>
)}
@@ -266,7 +307,21 @@ export function EditStorageComponent({
<EditNASStorageComponent
storage={storage}
setStorage={setStorage}
setIsUnsaved={setIsUnsaved}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestConnectionSuccess(false);
}}
/>
)}
{storage?.type === StorageType.AZURE_BLOB && (
<EditAzureBlobStorageComponent
storage={storage}
setStorage={setStorage}
setUnsaved={() => {
setIsUnsaved(true);
setIsTestConnectionSuccess(false);
}}
/>
)}
</div>

Some files were not shown because too many files have changed in this diff Show More