Compare commits

...

116 Commits

Author SHA1 Message Date
Rostislav Dugin
eb8e5aa428 FEATURE (storages): Add SFTP 2025-12-19 23:24:16 +03:00
github-actions[bot]
1f030bd8fb Update CITATION.cff to v2.8.1 2025-12-19 11:44:37 +00:00
Rostislav Dugin
b278a79104 FIX (databases): Remove optional text from db name field 2025-12-19 14:28:54 +03:00
github-actions[bot]
b74ae734af Update CITATION.cff to v2.8.0 2025-12-18 16:13:17 +00:00
Rostislav Dugin
d21a9398c6 FIX (Dockerfile): Upgrade Go version 2025-12-18 18:57:26 +03:00
Rostislav Dugin
6ad7b95b7d FIX (go tidy): Run go mod tidy 2025-12-18 18:42:02 +03:00
Rostislav Dugin
8432d1626f FIX (linting): Increase lint timeout 2025-12-18 18:36:11 +03:00
Rostislav Dugin
d7f631fa93 FIX (golangci): Upgrade version of golangci 2025-12-18 18:33:41 +03:00
Rostislav Dugin
c3fb2aa529 FIX (golangci): Upgrade version of golangci 2025-12-18 18:31:03 +03:00
Rostislav Dugin
1817937409 FIX (ci \ cd): Upgrade Go version 2025-12-18 18:16:37 +03:00
Rostislav Dugin
3172396668 FIX (extensions): Exclude extensions comments as well 2025-12-18 17:54:52 +03:00
Rostislav Dugin
9cd5c8c57c Merge branch 'main' of https://github.com/RostislavDugin/postgresus 2025-12-18 17:49:24 +03:00
Rostislav Dugin
d8826d85c3 FEATURE (storanges): Add rclone 2025-12-18 17:46:16 +03:00
github-actions[bot]
49fdd46cbe Update CITATION.cff to v2.7.0 2025-12-18 11:49:21 +00:00
Rostislav Dugin
c6261d434b FEATURE (restores): Allow to exclude extensions over restore 2025-12-18 14:34:32 +03:00
github-actions[bot]
918002acde Update CITATION.cff to v2.6.0 2025-12-17 14:03:33 +00:00
Rostislav Dugin
c0721a43e1 FEATURE (docs): Add code of conduct 2025-12-17 16:41:07 +03:00
Rostislav Dugin
461e15cd7a FEATURE (security): Add security md file 2025-12-17 16:33:10 +03:00
Rostislav Dugin
69a53936f5 FEATURE (citation): Add CITATION.cff 2025-12-17 16:17:43 +03:00
Rostislav Dugin
2bafec3c19 FIX (databases): Fix second opening of storage & notifier creation dialogs 2025-12-16 13:33:56 +03:00
Rostislav Dugin
422b44dfdc FEATURE (ftp): Get rid of passive mode 2025-12-14 00:01:21 +03:00
Rostislav Dugin
51d7fe54d0 Merge pull request #144 from omerkarabacak/main
FEATURE (clusters): Add cluster-based database management and bulk import
2025-12-13 22:37:35 +03:00
Omer Karabacak
6e2d63626c FEATURE (clusters): Add cluster-based database management and bulk import functionality 2025-12-13 20:32:54 +01:00
Rostislav Dugin
260c7a1188 FEATURE (frontend): Add frontend tests 2025-12-13 22:22:31 +03:00
Rostislav Dugin
ace94c144b FEATURE (storanges): Add FTP storange 2025-12-13 22:17:16 +03:00
Rostislav Dugin
b666cd9e2e Merge pull request #143 from RostislavDugin/develop
FEATURE (parsing): Add parsing connection string on DB creation
2025-12-13 13:53:30 +03:00
Rostislav Dugin
9dac63430d FEATURE (parsing): Add parsing connection string on DB creation 2025-12-13 13:50:22 +03:00
Rostislav Dugin
8217906c7a Merge pull request #139 from RostislavDugin/develop
Merge develop into main
2025-12-11 20:02:32 +03:00
Rostislav Dugin
db71a5ef7b FIX (databases): Add support dashed databases for read only users creation 2025-12-11 19:57:49 +03:00
Rostislav Dugin
df78e296b3 FEATURE (s3): Allow to skip TLS verification 2025-12-11 19:50:59 +03:00
Rostislav Dugin
fda3bf9b98 FEATURE (supabase): Add support of Supabase, schemas excluding and get rid of version in UI 2025-12-11 19:27:45 +03:00
pv-create
e19f449c60 FIX (readme): Fix typos and links
* fix typos

* fix link

* fix email param

---------

Co-authored-by: pavelvilkov <vilkovpy@mi-broker.ru>
2025-12-10 19:44:49 +03:00
Leonardo Flores
5944d7c4b6 feat(postgresus): Add schema filter for pg_dump and pg_restore (#131)
Add optional "Schemas" field to PostgreSQL database settings allowing
users to specify which schemas to include in backups (comma-separated).

This solves permission issues when backing up some of databases that
have restricted internal schemas (auth, storage, realtime).

Changes:
- Add schemas column to postgresql_databases table (migration)
- Update PostgresqlDatabase model with Schemas field
- Modify buildPgDumpArgs() to append --schema flags for each schema
- Modify pg_restore args to support --schema filtering on restore
- Add Schemas input field to frontend edit form with tooltip
- Display schemas in read-only database view

Example usage: Setting schemas to "public,drizzle" generates:
  pg_dump ... --schema public --schema drizzle
  pg_restore ... --schema public --schema drizzle
2025-12-10 13:19:15 +03:00
Unicorn-Zombie-Apocalypse
1f5c9d3d01 feat: Add support for custom Root CA configuration in Helm chart (#129)
* feat: Add support for custom Root CA configuration in Helm chart

* fix: Remove default value for customRootCA in Helm chart
2025-12-09 19:36:52 +03:00
Rostislav Dugin
d27b885fc1 FIX (postgresql): Fix version detection without minor version after major 2025-12-09 10:36:07 +03:00
Rostislav Dugin
45054bc4b5 FIX (readme): Update README about PITR 2025-12-08 22:20:41 +03:00
Rostislav Dugin
09f27019e8 FIX (postgresql): Use UTF-8 encoding for DB connection by default 2025-12-08 17:40:37 +03:00
Rostislav Dugin
cba8fdf49c FEATURE (core)!: Release 2.0 2025-12-08 10:41:36 +03:00
Rostislav Dugin
41c72cf7b6 FIX (buffering): Simplify buffering logic for localstorage 2025-12-07 19:40:40 +03:00
Rostislav Dugin
f04a8b7a82 FIX (backup): Add double buffering for local storange 2025-12-07 19:02:44 +03:00
Rostislav Dugin
552167e4ef FIX (logos): Update logos 2025-12-07 18:46:39 +03:00
Rostislav Dugin
be42cfab1f Merge branch 'main' of https://github.com/RostislavDugin/postgresus 2025-12-07 17:50:05 +03:00
Rostislav Dugin
ea34ced676 Merge pull request #124 from akalitenya/helm-values-tag-fix
Set default helm chart image tag to null
2025-12-07 17:49:21 +03:00
Rostislav Dugin
09cb1488b3 FIX (notifications): Get rid of password validation for email 2025-12-07 17:48:11 +03:00
Rostislav Dugin
b6518ef667 FIX (buffers): Increase copy buffer size 2025-12-07 17:44:35 +03:00
akalitenya
25c58e6209 set default image tag to null 2025-12-07 10:34:18 +05:00
Rostislav Dugin
97ee4b55c2 FIX (helm): Use standard namespace behavior instead of hardcoded values 2025-12-04 19:59:19 +03:00
Rostislav Dugin
12eea72392 FEATURE (helm): Use ClusterIP by default and add deployment to ghcr.io 2025-12-04 15:11:09 +03:00
Rostislav Dugin
75c88bac50 FIX (webhook): Escape webhook characters 2025-12-04 14:28:49 +03:00
Rostislav Dugin
ff1b6536bf FIX (connection): Add standard_conforming_strings param when building string to connect to PG 2025-12-03 18:42:49 +03:00
Rostislav Dugin
06197f986d FIX (chunking): Add backuping chunk by chunk without buffering in RAM and improve cancelation process 2025-12-03 17:35:43 +03:00
Rostislav Dugin
fe72e9e0a6 FIX (healthcheck): Clean up healthcheck interval receving when tab changed 2025-12-03 08:08:49 +03:00
Rostislav Dugin
640cceadbd FIX (docs): Extend docs with HTTP route support 2025-12-03 07:43:00 +03:00
Rostislav Dugin
80e573fcb3 Merge pull request #121 from tylerobara/feature/add_httproute_support
FEATURE helm: Adding support for HTTPRoutes
2025-12-03 07:34:20 +03:00
Tyler Obara
35498d83f1 adding support for httperoutes 2025-12-02 17:01:38 -05:00
Rostislav Dugin
77ae8d1ac7 FIX (helm): Fix Helm path in readmes 2025-12-02 17:43:43 +03:00
Rostislav Dugin
2f20845b3d Merge branch 'main' of https://github.com/RostislavDugin/postgresus 2025-12-02 17:41:02 +03:00
Rostislav Dugin
a3d3df4093 FIX (zoom): Disable zoom on iOS 2025-12-02 17:40:43 +03:00
Rostislav Dugin
8db83d40d5 FIX (mobile): Do not preselect card on mobile for DBs, notifiers and storanges 2025-12-02 17:37:03 +03:00
Rostislav Dugin
065ded37bd Merge pull request #119 from tylerobara/fix/helm_liveness_readiness
FIX Helm: Templates, Liveness and Readiness probes
2025-12-02 17:15:50 +03:00
Tyler Obara
71e801debb change helm dir 2025-12-02 08:44:46 -05:00
Tyler Obara
ffd4e3a27b fixing liveness and readiness probes 2025-12-02 08:02:26 -05:00
Rostislav Dugin
d2a9085591 FIX (dump): Get rid of extra encoding param when backup and restore 2025-12-02 12:54:07 +03:00
Rostislav Dugin
6f0152b60c FIX (helm): Get rid of ingress by default 2025-12-02 10:03:47 +03:00
Rostislav Dugin
7007236f2f FIX (email): Recrate client in case of auth error 2025-12-02 09:43:49 +03:00
Rostislav Dugin
db55cad310 Merge pull request #116 from RostislavDugin/feature/helm_chart
FIX (helm): Add git clone step
2025-12-02 00:02:13 +03:00
Rostislav Dugin
25bd096c81 FIX (helm): Add git clone step 2025-12-01 23:57:05 +03:00
Rostislav Dugin
7e98dd578c Merge pull request #115 from RostislavDugin/feature/helm_chart
Feature/helm chart
2025-12-01 23:47:27 +03:00
Rostislav Dugin
ba37b30e83 FEATURE (helm): Add Helm chart installation 2025-12-01 23:47:00 +03:00
Rostislav Dugin
34b3f822e3 Merge pull request #114 from spa-skyson/helmchart
helmchart v1.0.0
2025-12-01 23:18:20 +03:00
Rostislav Dugin
14700130b7 FIX (email): Add login auth in case if plain fails 2025-12-01 23:16:54 +03:00
Alexander Gazal
de11ab8d8a helmchart v1.0.0 2025-12-01 08:47:17 +03:00
Rostislav Dugin
06282bb435 FIX (connection): Avoid usage of prepare statements to get rid of problem with PgBounder 2025-11-30 20:50:25 +03:00
Rostislav Dugin
a3b263bbac FIX (installation): Fix installation on Debian 2025-11-30 20:25:28 +03:00
Rostislav Dugin
a956dccf7c FIX (whitelist): Show hint about Postgresus whitelist in case of connection failure 2025-11-28 23:59:20 +03:00
Rostislav Dugin
ce9fa18d58 FEATURE (webhook): Add webhook customization 2025-11-28 21:53:44 +03:00
Rostislav Dugin
281e185f21 FIX (dark): Add dark theme image 2025-11-27 23:17:43 +03:00
Rostislav Dugin
bb5b0064ea Merge branch 'main' of https://github.com/RostislavDugin/postgresus 2025-11-27 22:19:34 +03:00
Rostislav Dugin
da95bbb178 FIX (s3): Do not allow to change prefix after creation 2025-11-27 22:00:21 +03:00
Rostislav Dugin
cfe5993831 Merge pull request #110 from RostislavDugin/feature/pgpass_escape
Feature/pgpass escape
2025-11-27 17:03:06 +03:00
Rostislav Dugin
fa0e3d1ce2 REFACTOR (pgpass): Refactor escaping 2025-11-27 17:00:26 +03:00
Rostislav Dugin
d07085c462 Merge pull request #108 from kapawit/fix/pgpass-special-characters
FIX (postgresql): Escape special characters in .pgpass file for authentication
2025-11-27 16:54:38 +03:00
kapawit
c89c1f9654 FIX (postgresql): Escape special characters in .pgpass file for authentication 2025-11-26 21:35:38 +07:00
Rostislav Dugin
6cfc0ca79b FEATURE (dark): Add dark theme 2025-11-26 00:07:23 +03:00
Rostislav Dugin
5d27123bd7 FEATURE (adaptivity): Add mobile adaptivity 2025-11-25 21:40:46 +03:00
Rostislav Dugin
79ca374bb6 FEATURE (notifiers): Add mobile adaptivity for notifiers 2025-11-23 23:43:58 +03:00
Rostislav Dugin
b3f1a6f7e5 FEATURE (databases): Add adaptivity for mobile databases 2025-11-23 20:23:05 +03:00
Rostislav Dugin
d521e2abc6 FIX (slack): Add request timeout for 30 seconds 2025-11-23 18:19:28 +03:00
Rostislav Dugin
82eca7501b FEATURE (security): Clean PostgreSQL creds after restore 2025-11-21 20:30:12 +03:00
Rostislav Dugin
51866437fd FEATURE (secutiry): Add read-only user creation before Postgresus backups 2025-11-21 19:14:13 +03:00
Rostislav Dugin
244a56d1bb FEATURE (secrets): Move secrets to the secret.key file instead of DB 2025-11-19 18:53:58 +03:00
Rostislav Dugin
95c833b619 FIX (backups): Fix passing encypted password to .pgpass 2025-11-19 17:10:19 +03:00
Rostislav Dugin
878fad5747 FEATURE (encryption): Add encyption for secrets in notifiers and storages 2025-11-18 21:23:59 +03:00
Rostislav Dugin
6ff3096695 FIX (password reset): Allow to change user password even if password was not set before 2025-11-17 20:20:31 +03:00
Rostislav Dugin
b4b514c2d5 FEATURE (encryption): Add backups encryption 2025-11-17 14:33:37 +03:00
Rostislav Dugin
da0fec6624 FEATURE (azure): Add Azure Blob Storage 2025-11-16 23:38:20 +03:00
Rostislav Dugin
408675023a FEATURE (s3): Add support of virtual-styled-domains and S3 prefix 2025-11-16 11:22:03 +03:00
Rostislav Dugin
0bc93389cc FEATURE (backups): Include workspace name in notification about success or fail 2025-11-15 11:40:42 +03:00
Rostislav Dugin
c8e6aea6e1 FEATURE (hints): Add hints about localhost connection 2025-11-15 00:25:51 +03:00
Rostislav Dugin
981ad21471 FEATURE (email): Add "to" header to email 2025-11-14 20:39:02 +03:00
Rostislav Dugin
177a9c782c Revert "FIX (notifiers): Improve email validation"
This reverts commit 02c735bc5a.
2025-11-14 20:35:22 +03:00
Rostislav Dugin
069d6bc8fe FEATURE (logo): Update logo 2025-11-14 20:19:26 +03:00
Rostislav Dugin
242d5543d4 FIX (backups): Avoid possibility of breaking DB on backup fail 2025-11-14 19:56:56 +03:00
Rostislav Dugin
02c735bc5a FIX (notifiers): Improve email validation 2025-11-14 18:02:27 +03:00
Rostislav Dugin
793b575146 FIX (storages): Ignore files removal errors for unavailable storage when deleting the database 2025-11-14 18:02:13 +03:00
Rostislav Dugin
a6e84b45f2 Merge pull request #84 from RostislavDugin/feature/add_pg_12
Feature/add pg 12
2025-11-12 15:43:09 +03:00
Rostislav Dugin
a941fbd093 FEATURE (postgres): Add PostgreSQL 12 tests and CI \ CD config 2025-11-12 15:39:44 +03:00
Rostislav Dugin
4492ba41f5 Merge pull request #82 from romanesko/feature/v12-support
feat: add PostgreSQL 12 support
2025-11-12 15:04:12 +03:00
Roman Bykovsky
3a5ac4b479 feat: add PostgreSQL 12 support 2025-11-11 18:53:26 +03:00
Rostislav Dugin
77aaabeaa1 FEATURE (docs): Update readme and docs links 2025-11-11 16:56:33 +03:00
Rostislav Dugin
01911dbf72 FIX (notifiers & storages): Avoid request for workspace_id for storages and notifiers removal 2025-11-11 10:05:45 +03:00
Rostislav Dugin
1a16f27a5d FIX (notifiers): Fix update of existing DB notifiers 2025-11-11 08:10:02 +03:00
Rostislav Dugin
778db71625 FIX (tests): Improve tests stability in CI \ CD 2025-11-09 20:41:36 +03:00
Rostislav Dugin
45fc9a7fff FIX (databases): Verify DB nil on side of DB instead of interface 2025-11-09 20:03:22 +03:00
Rostislav Dugin
7f5e786261 FIX (databases): If some DB missing PostgreSQL db fix nil issue 2025-11-09 18:57:42 +03:00
Rostislav Dugin
9b066bcb8a FEATURE (email): Add "from" field 2025-11-08 20:47:35 +03:00
250 changed files with 19769 additions and 4467 deletions

102
.github/CODE_OF_CONDUCT.md vendored Normal file
View File

@@ -0,0 +1,102 @@
# Code of Conduct
## Our Pledge
We as members, contributors and maintainers pledge to make participation in the Postgresus community a friendly and welcoming experience for everyone, regardless of background, experience level or personal circumstances.
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive and healthy community.
## Our Standards
### Examples of behavior that contributes to a positive environment
- Using welcoming and inclusive language
- Being respectful of differing viewpoints and experiences
- Gracefully accepting constructive criticism
- Focusing on what is best for the community
- Showing empathy towards other community members
- Helping newcomers get started with contributions
- Providing clear and constructive feedback on pull requests
- Celebrating successes and acknowledging contributions
### Examples of unacceptable behavior
- Trolling, insulting or derogatory comments, and personal or political attacks
- Publishing others' private information, such as physical or email addresses, without their explicit permission
- Spam, self-promotion or off-topic content in project spaces
- Other conduct which could reasonably be considered inappropriate in a professional setting
## Scope
This Code of Conduct applies within all community spaces, including:
- GitHub repositories (issues, pull requests, discussions, comments)
- Telegram channels and direct messages related to Postgresus
- Social media interactions when representing the project
- Community forums and online discussions
- Any other spaces where Postgresus community members interact
This Code of Conduct also applies when an individual is officially representing the community in public spaces, such as using an official email address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
## Enforcement
Instances of abusive or unacceptable behavior may be reported to the community leaders responsible for enforcement:
- **Email**: [info@postgresus.com](mailto:info@postgresus.com)
- **Telegram**: [@rostislav_dugin](https://t.me/rostislav_dugin)
All complaints will be reviewed and investigated promptly and fairly.
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
## Enforcement Guidelines
Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
### 1. Correction
**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
### 2. Warning
**Community Impact**: A violation through a single incident or series of actions.
**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
### 3. Temporary Ban
**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
### 4. Permanent Ban
**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the community.
## Contributing with Respect
When contributing to Postgresus, please:
- Be patient with maintainers and other contributors
- Understand that everyone has different levels of experience
- Ask questions in a respectful manner
- Accept that your contribution may not be accepted, and be open to feedback
- Follow the [contribution guidelines](https://postgresus.com/contribute)
For code contributions, remember to:
- Discuss significant changes before implementing them
- Be open to code review feedback
- Help review others' contributions when possible
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html).
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq).

54
.github/SECURITY.md vendored Normal file
View File

@@ -0,0 +1,54 @@
# Security Policy
## Reporting a Vulnerability
If you discover a security vulnerability in Postgresus, please report it responsibly. **Do not create a public GitHub issue for security vulnerabilities.**
### How to Report
1. **Email** (preferred): Send details to [info@postgresus.com](mailto:info@postgresus.com)
2. **Telegram**: Contact [@rostislav_dugin](https://t.me/rostislav_dugin)
3. **GitHub Security Advisories**: Use the [private vulnerability reporting](https://github.com/RostislavDugin/postgresus/security/advisories/new) feature
### What to Include
- Description of the vulnerability
- Steps to reproduce the issue
- Potential impact and severity assessment
- Any suggested fixes (optional)
## Supported Versions
| Version | Supported |
| ------- | --------- |
| Latest | Yes |
We recommend always using the latest version of Postgresus. Security patches are applied to the most recent release.
### PostgreSQL Compatibility
Postgresus supports PostgreSQL versions 12, 13, 14, 15, 16, 17 and 18.
## Response Timeline
- **Acknowledgment**: Within 48-72 hours
- **Initial Assessment**: Within 1 week
- **Fix Timeline**: Depends on severity, but we aim to address critical issues as quickly as possible
We follow a coordinated disclosure policy. We ask that you give us reasonable time to address the vulnerability before any public disclosure.
## Security Features
Postgresus is designed with security in mind. For full details, see our [security documentation](https://postgresus.com/security).
Key features include:
- **AES-256-GCM Encryption**: Enterprise-grade encryption for backup files and sensitive data
- **Read-Only Database Access**: Postgresus uses read-only access by default and warns if write permissions are detected
- **Role-Based Access Control**: Assign viewer, member, admin or owner roles within workspaces
- **Audit Logging**: Track all system activities and changes made by users
- **Zero-Trust Storage**: Encrypted backups are safe even in shared cloud storage
## License
Postgresus is licensed under [Apache 2.0](../LICENSE).

View File

@@ -17,7 +17,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23.3"
go-version: "1.24.4"
- name: Cache Go modules
uses: actions/cache@v4
@@ -31,7 +31,7 @@ jobs:
- name: Install golangci-lint
run: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.60.3
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
- name: Install swag for swagger generation
@@ -82,6 +82,30 @@ jobs:
cd frontend
npm run lint
test-frontend:
runs-on: ubuntu-latest
needs: [lint-frontend]
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
cd frontend
npm ci
- name: Run frontend tests
run: |
cd frontend
npm run test
test-backend:
runs-on: ubuntu-latest
needs: [lint-backend]
@@ -92,7 +116,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.23.3"
go-version: "1.24.4"
- name: Cache Go modules
uses: actions/cache@v4
@@ -127,6 +151,7 @@ jobs:
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
TEST_GOOGLE_DRIVE_TOKEN_JSON=${{ secrets.TEST_GOOGLE_DRIVE_TOKEN_JSON }}
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -136,11 +161,23 @@ jobs:
# testing S3
TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing Azure Blob
TEST_AZURITE_BLOB_PORT=10000
# testing NAS
TEST_NAS_PORT=7006
# testing FTP
TEST_FTP_PORT=7007
# testing SFTP
TEST_SFTP_PORT=7008
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
# supabase
TEST_SUPABASE_HOST=${{ secrets.TEST_SUPABASE_HOST }}
TEST_SUPABASE_PORT=${{ secrets.TEST_SUPABASE_PORT }}
TEST_SUPABASE_USERNAME=${{ secrets.TEST_SUPABASE_USERNAME }}
TEST_SUPABASE_PASSWORD=${{ secrets.TEST_SUPABASE_PASSWORD }}
TEST_SUPABASE_DATABASE=${{ secrets.TEST_SUPABASE_DATABASE }}
EOF
- name: Start test containers
@@ -154,6 +191,7 @@ jobs:
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
@@ -163,6 +201,15 @@ jobs:
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
# Wait for FTP
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
# Wait for SFTP
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
- name: Create data and temp directories
run: |
# Create directories that are used for backups and restore
@@ -185,7 +232,7 @@ jobs:
- name: Run Go tests
run: |
cd backend
go test ./internal/...
go test -p=1 -count=1 -failfast ./internal/...
- name: Stop test containers
if: always()
@@ -195,7 +242,7 @@ jobs:
determine-version:
runs-on: ubuntu-latest
needs: [test-backend, lint-frontend]
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
outputs:
should_release: ${{ steps.version_bump.outputs.should_release }}
@@ -288,7 +335,7 @@ jobs:
build-only:
runs-on: ubuntu-latest
needs: [test-backend, lint-frontend]
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
steps:
- name: Check out code
@@ -448,6 +495,17 @@ jobs:
echo EOF
} >> $GITHUB_OUTPUT
- name: Update CITATION.cff version
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
git config user.name "github-actions[bot]"
git config user.email "github-actions[bot]@users.noreply.github.com"
git add CITATION.cff
git commit -m "Update CITATION.cff to v${VERSION}" || true
git push || true
- name: Create GitHub Release
uses: actions/create-release@v1
env:
@@ -458,3 +516,37 @@ jobs:
body: ${{ steps.changelog.outputs.changelog }}
draft: false
prerelease: false
publish-helm-chart:
runs-on: ubuntu-latest
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: read
packages: write
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Helm
uses: azure/setup-helm@v4
with:
version: v3.14.0
- name: Log in to GHCR
run: echo "${{ secrets.GITHUB_TOKEN }}" | helm registry login ghcr.io -u ${{ github.actor }} --password-stdin
- name: Update Chart.yaml with release version
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
sed -i "s/^version: .*/version: ${VERSION}/" deploy/helm/Chart.yaml
sed -i "s/^appVersion: .*/appVersion: \"v${VERSION}\"/" deploy/helm/Chart.yaml
cat deploy/helm/Chart.yaml
- name: Package Helm chart
run: helm package deploy/helm --destination .
- name: Push Helm chart to GHCR
run: |
VERSION="${{ needs.determine-version.outputs.new_version }}"
helm push postgresus-${VERSION}.tgz oci://ghcr.io/rostislavdugin/charts

6
.gitignore vendored
View File

@@ -4,4 +4,8 @@ postgresus-data/
pgdata/
docker-compose.yml
node_modules/
.idea
.idea
/articles
.DS_Store
/scripts

33
CITATION.cff Normal file
View File

@@ -0,0 +1,33 @@
cff-version: 1.2.0
title: Postgresus
message: "If you use this software, please cite it as below."
type: software
authors:
- family-names: Dugin
given-names: Rostislav
repository-code: https://github.com/RostislavDugin/postgresus
url: https://postgresus.com
abstract: "Free, open source and self-hosted solution for automated PostgreSQL backups with multiple storage options and notifications."
keywords:
- docker
- kubernetes
- golang
- backups
- postgres
- devops
- backup
- database
- tools
- monitoring
- ftp
- postgresql
- s3
- psql
- web-ui
- self-hosted
- pg
- system-administration
- database-backup
license: Apache-2.0
version: 2.8.1
date-released: "2025-12-19"

View File

@@ -22,7 +22,7 @@ RUN npm run build
# ========= BUILD BACKEND =========
# Backend build stage
FROM --platform=$BUILDPLATFORM golang:1.23.3 AS backend-build
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
# Make TARGET args available early so tools built here match the final image arch
ARG TARGETOS
@@ -77,16 +77,16 @@ ENV APP_VERSION=$APP_VERSION
# Set production mode for Docker containers
ENV ENV_MODE=production
# Install PostgreSQL server and client tools (versions 13-17)
# Install PostgreSQL server and client tools (versions 12-18) and rclone
RUN apt-get update && apt-get install -y --no-install-recommends \
wget ca-certificates gnupg lsb-release sudo gosu && \
wget ca-certificates gnupg lsb-release sudo gosu curl unzip && \
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \
> /etc/apt/sources.list.d/pgdg.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
postgresql-17 postgresql-18 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-client-16 postgresql-client-17 postgresql-client-18 rclone && \
rm -rf /var/lib/apt/lists/*
# Create postgres user and set up directories

View File

@@ -9,7 +9,7 @@
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/RostislavDugin/postgresus)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-13%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-12%20%7C%2013%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/RostislavDugin/postgresus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/RostislavDugin/postgresus)
@@ -25,6 +25,8 @@
<a href="https://postgresus.com" target="_blank"><strong>🌐 Postgresus website</strong></a>
</p>
<img src="assets/dashboard-dark.svg" alt="Postgresus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
<img src="assets/dashboard.svg" alt="Postgresus Dashboard" width="800"/>
@@ -40,13 +42,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗄️ **Multiple Storage Destinations**
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
- **Local storage**: Keep backups on your VPS/server
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
- **Secure**: All data stays under your control
### 📱 **Smart Notifications**
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
- **Real-time updates**: Success and failure notifications
@@ -54,19 +56,48 @@
### 🐘 **PostgreSQL Support**
- **Multiple versions**: PostgreSQL 13, 14, 15, 16, 17 and 18
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
- **SSL support**: Secure connections available
- **Easy restoration**: One-click restore from any backup
### 🔒 **Enterprise-grade security** <a href="https://postgresus.com/security">(docs)</a>
- **AES-256-GCM encryption**: Enterprise-grade protection for backup files
- **Zero-trust storage**: Backups are encrypted and they are useless to attackers, so you can keep them in shared storages like S3, Azure Blob Storage, etc.
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
- **Read-only user**: Postgresus uses by default a read-only user for backups and never stores anything that can change your data
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
- **Access management**: Control who can view or manage specific databases with role-based permissions
- **Audit logs**: Track all system activities and changes made by users
- **User roles**: Assign viewer, member, admin or owner roles within workspaces
### 🎨 **UX-Friendly**
- **Designer-polished UI**: Clean, intuitive interface crafted with attention to detail
- **Dark & light themes**: Choose the look that suits your workflow
- **Mobile adaptive**: Check your backups from anywhere on any device
### ☁️ **Works with Self-Hosted & Cloud Databases**
Postgresus works seamlessly with both self-hosted PostgreSQL and cloud-managed databases:
- **Cloud support**: AWS RDS, Google Cloud SQL, Azure Database for PostgreSQL
- **Self-hosted**: Any PostgreSQL instance you manage yourself
- **Why no PITR?**: Cloud providers already offer native PITR, and external PITR backups cannot be restored to managed cloud databases — making them impractical for cloud-hosted PostgreSQL
- **Practical granularity**: Hourly and daily backups are sufficient for 99% of projects without the operational complexity of WAL archiving
### 🐳 **Self-Hosted & Secure**
- **Docker-based**: Easy deployment and management
- **Privacy-first**: All your data stays on your infrastructure
- **Open source**: Apache 2.0 licensed, inspect every line of code
### 📦 Installation
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
You have three ways to install Postgresus:
You have several ways to install Postgresus:
- Script (recommended)
- Simple Docker run
@@ -84,7 +115,7 @@ You have three ways to install Postgresus: automated script (recommended), simpl
The installation script will:
- ✅ Install Docker with Docker Compose(if not already installed)
- ✅ Install Docker with Docker Compose (if not already installed)
- ✅ Set up Postgresus
- ✅ Configure automatic startup on system reboot
@@ -118,8 +149,6 @@ This single command will:
Create a `docker-compose.yml` file with the following configuration:
```yaml
version: "3"
services:
postgresus:
container_name: postgresus
@@ -137,6 +166,46 @@ Then run:
docker compose up -d
```
### Option 4: Kubernetes with Helm
For Kubernetes deployments, install directly from the OCI registry.
**With ClusterIP + port-forward (development/testing):**
```bash
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
-n postgresus --create-namespace
```
```bash
kubectl port-forward svc/postgresus-service 4005:4005 -n postgresus
# Access at http://localhost:4005
```
**With LoadBalancer (cloud environments):**
```bash
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
-n postgresus --create-namespace \
--set service.type=LoadBalancer
```
```bash
kubectl get svc postgresus-service -n postgresus
# Access at http://<EXTERNAL-IP>:4005
```
**With Ingress (domain-based access):**
```bash
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
-n postgresus --create-namespace \
--set ingress.enabled=true \
--set ingress.hosts[0].host=backup.example.com
```
For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart README](deploy/helm/README.md).
---
## 🚀 Usage
@@ -149,9 +218,9 @@ docker compose up -d
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
### 🔑 Resetting Admin Password
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
If you need to reset the admin password, you can use the built-in password reset command:
If you need to reset the password, you can use the built-in password reset command:
```bash
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123" --email="admin"
@@ -163,10 +232,10 @@ Replace `admin` with the actual email address of the user whose password you wan
## 📝 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
---
## 🤝 Contributing
Contributions are welcome! Read [contributing guide](contribute/README.md) for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
Contributions are welcome! Read <a href="https://postgresus.com/contribute">contributing guide</a> for more details, priorities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)

764
assets/dashboard-dark.svg Normal file

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 766 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 791 KiB

After

Width:  |  Height:  |  Size: 771 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 13 KiB

View File

@@ -9,4 +9,4 @@ When applying changes, do not forget to refactor old code.
You can shortify, make more readable, improve code quality, etc.
Common logic can be extracted to functions, constants, files, etc.
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder).
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).

View File

@@ -17,6 +17,7 @@ TEST_GOOGLE_DRIVE_CLIENT_ID=
TEST_GOOGLE_DRIVE_CLIENT_SECRET=
TEST_GOOGLE_DRIVE_TOKEN_JSON="{\"access_token\":\"ya29..."
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -30,4 +31,16 @@ TEST_MINIO_CONSOLE_PORT=9001
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=
TEST_TELEGRAM_CHAT_ID=
TEST_TELEGRAM_CHAT_ID=
# testing Azure Blob Storage
TEST_AZURITE_BLOB_PORT=10000
# supabase
TEST_SUPABASE_HOST=
TEST_SUPABASE_PORT=
TEST_SUPABASE_USERNAME=
TEST_SUPABASE_PASSWORD=
TEST_SUPABASE_DATABASE=
# FTP
TEST_FTP_PORT=7007
# SFTP
TEST_SFTP_PORT=7008

View File

@@ -1,7 +1,7 @@
version: "2"
run:
timeout: 1m
timeout: 5m
tests: false
concurrency: 4

View File

@@ -18,6 +18,7 @@ import (
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/disk"
"postgresus-backend/internal/features/encryption/secrets"
healthcheck_attempt "postgresus-backend/internal/features/healthcheck/attempt"
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
@@ -64,6 +65,12 @@ func main() {
os.Exit(1)
}
err = secrets.GetSecretKeyService().MigrateKeyFromDbToFileIfExist()
if err != nil {
log.Error("Failed to migrate secret key from database to file", "error", err)
os.Exit(1)
}
err = users_services.GetUserService().CreateInitialAdmin()
if err != nil {
log.Error("Failed to create initial admin", "error", err)

View File

@@ -31,7 +31,26 @@ services:
container_name: test-minio
command: server /data --console-address ":9001"
# Test Azurite container
test-azurite:
image: mcr.microsoft.com/azure-storage/azurite
ports:
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
container_name: test-azurite
command: azurite-blob --blobHost 0.0.0.0
# Test PostgreSQL containers
test-postgres-12:
image: postgres:12
ports:
- "${TEST_POSTGRES_12_PORT}:5432"
environment:
- POSTGRES_DB=testdb
- POSTGRES_USER=testuser
- POSTGRES_PASSWORD=testpassword
container_name: test-postgres-12
shm_size: 1gb
test-postgres-13:
image: postgres:13
ports:
@@ -113,3 +132,25 @@ services:
-s "backups;/shared;yes;no;no;testuser"
-p
container_name: test-nas
# Test FTP server
test-ftp:
image: stilliard/pure-ftpd:latest
ports:
- "${TEST_FTP_PORT:-21}:21"
- "30000-30009:30000-30009"
environment:
- PUBLICHOST=localhost
- FTP_USER_NAME=testuser
- FTP_USER_PASS=testpassword
- FTP_USER_HOME=/home/ftpusers/testuser
- FTP_PASSIVE_PORTS=30000:30009
container_name: test-ftp
# Test SFTP server
test-sftp:
image: atmoz/sftp:latest
ports:
- "${TEST_SFTP_PORT:-7008}:22"
command: testuser:testpassword:1001::upload
container_name: test-sftp

View File

@@ -1,8 +1,10 @@
module postgresus-backend
go 1.23.3
go 1.24.4
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -10,33 +12,195 @@ require (
github.com/google/uuid v1.6.0
github.com/ilyakaznacheev/cleanenv v1.5.0
github.com/jackc/pgx/v5 v5.7.5
github.com/jlaffaye/ftp v0.2.1-0.20240918233326-1b970516f5d3
github.com/jmoiron/sqlx v1.4.0
github.com/joho/godotenv v1.5.1
github.com/lib/pq v1.10.9
github.com/minio/minio-go/v7 v7.0.92
github.com/shirou/gopsutil/v4 v4.25.5
github.com/stretchr/testify v1.10.0
github.com/minio/minio-go/v7 v7.0.97
github.com/pkg/sftp v1.13.10
github.com/rclone/rclone v1.72.1
github.com/shirou/gopsutil/v4 v4.25.10
github.com/stretchr/testify v1.11.1
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
golang.org/x/crypto v0.39.0
golang.org/x/time v0.12.0
golang.org/x/crypto v0.46.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
require (
cloud.google.com/go/auth v0.16.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
github.com/Azure/azure-sdk-for-go/sdk/storage/azfile v1.5.3 // indirect
github.com/Azure/go-ntlmssp v0.0.2-0.20251110135918-10b7b7e7cd26 // indirect
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
github.com/Files-com/files-sdk-go/v3 v3.2.264 // indirect
github.com/IBM/go-sdk-core/v5 v5.21.0 // indirect
github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
github.com/ProtonMail/gluon v0.17.1-0.20230724134000-308be39be96e // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f // indirect
github.com/ProtonMail/go-srp v0.0.7 // indirect
github.com/ProtonMail/gopenpgp/v2 v2.9.0 // indirect
github.com/PuerkitoBio/goquery v1.10.3 // indirect
github.com/a1ex3/zstd-seekable-format-go/pkg v0.10.0 // indirect
github.com/abbot/go-http-auth v0.4.0 // indirect
github.com/anchore/go-lzo v0.1.0 // indirect
github.com/andybalholm/cascadia v1.3.3 // indirect
github.com/appscode/go-querystring v0.0.0-20170504095604-0126cfb3f1dc // indirect
github.com/aws/aws-sdk-go-v2 v1.39.6 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 // indirect
github.com/aws/aws-sdk-go-v2/config v1.31.17 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.18.21 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 // indirect
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.20.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.13 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect
github.com/aws/aws-sdk-go-v2/service/s3 v1.90.0 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect
github.com/aws/smithy-go v1.23.2 // indirect
github.com/bahlo/generic-list-go v0.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/boombuler/barcode v1.1.0 // indirect
github.com/bradenaw/juniper v0.15.3 // indirect
github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect
github.com/buengese/sgzip v0.1.1 // indirect
github.com/buger/jsonparser v1.1.1 // indirect
github.com/calebcase/tmpfile v1.0.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/chilts/sid v0.0.0-20190607042430-660e94789ec9 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudinary/cloudinary-go/v2 v2.13.0 // indirect
github.com/cloudsoda/go-smb2 v0.0.0-20250228001242-d4c70e6251cc // indirect
github.com/cloudsoda/sddl v0.0.0-20250224235906-926454e91efc // indirect
github.com/colinmarc/hdfs/v2 v2.4.0 // indirect
github.com/coreos/go-semver v0.3.1 // indirect
github.com/coreos/go-systemd/v22 v22.6.0 // indirect
github.com/creasty/defaults v1.8.0 // indirect
github.com/cronokirby/saferith v0.33.0 // indirect
github.com/diskfs/go-diskfs v1.7.0 // indirect
github.com/dropbox/dropbox-sdk-go-unofficial/v6 v6.0.5 // indirect
github.com/emersion/go-message v0.18.2 // indirect
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
github.com/flynn/noise v1.1.0 // indirect
github.com/go-chi/chi/v5 v5.2.3 // indirect
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
github.com/go-git/go-billy/v5 v5.6.2 // indirect
github.com/go-openapi/errors v0.22.4 // indirect
github.com/go-openapi/strfmt v0.25.0 // indirect
github.com/go-resty/resty/v2 v2.16.5 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/gofrs/flock v0.13.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/gorilla/schema v1.4.1 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
github.com/hashicorp/go-uuid v1.0.3 // indirect
github.com/henrybear327/Proton-API-Bridge v1.0.0 // indirect
github.com/henrybear327/go-proton-api v1.0.0 // indirect
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
github.com/jcmturner/gofork v1.7.6 // indirect
github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 // indirect
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 // indirect
github.com/klauspost/crc32 v1.3.0 // indirect
github.com/koofr/go-httpclient v0.0.0-20240520111329-e20f8f203988 // indirect
github.com/koofr/go-koofrclient v0.0.0-20221207135200-cbd7fc9ad6a6 // indirect
github.com/kr/fs v0.1.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lanrat/extsort v1.4.2 // indirect
github.com/lpar/date v1.0.0 // indirect
github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/ncw/swift/v2 v2.0.5 // indirect
github.com/oklog/ulid v1.3.1 // indirect
github.com/oracle/oci-go-sdk/v65 v65.104.0 // indirect
github.com/panjf2000/ants/v2 v2.11.3 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/pengsrc/go-shared v0.2.1-0.20190131101655-1999055a4a14 // indirect
github.com/peterh/liner v1.2.2 // indirect
github.com/pierrec/lz4/v4 v4.1.22 // indirect
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pkg/xattr v0.4.12 // indirect
github.com/pquerna/otp v1.5.0 // indirect
github.com/prometheus/client_golang v1.23.2 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.2 // indirect
github.com/prometheus/procfs v0.19.2 // indirect
github.com/putdotio/go-putio/putio v0.0.0-20200123120452-16d982cac2b8 // indirect
github.com/relvacode/iso8601 v1.7.0 // indirect
github.com/rfjakob/eme v1.1.2 // indirect
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // indirect
github.com/samber/lo v1.52.0 // indirect
github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af // indirect
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
github.com/sony/gobreaker v1.0.0 // indirect
github.com/spacemonkeygo/monkit/v3 v3.0.25-0.20251022131615-eb24eb109368 // indirect
github.com/spf13/pflag v1.0.10 // indirect
github.com/t3rm1n4l/go-mega v0.0.0-20251031123324-a804aaa87491 // indirect
github.com/tklauser/go-sysconf v0.3.15 // indirect
github.com/tklauser/numcpus v0.10.0 // indirect
github.com/ulikunitz/xz v0.5.15 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
github.com/yunify/qingstor-sdk-go/v3 v3.2.0 // indirect
github.com/zeebo/blake3 v0.2.4 // indirect
github.com/zeebo/errs v1.4.0 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.etcd.io/bbolt v1.4.3 // indirect
go.mongodb.org/mongo-driver v1.17.6 // indirect
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/term v0.38.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/validator.v2 v2.0.1 // indirect
moul.io/http2curl/v2 v2.3.0 // indirect
sigs.k8s.io/yaml v1.6.0 // indirect
storj.io/common v0.0.0-20251107171817-6221ae45072c // indirect
storj.io/drpc v0.0.35-0.20250513201419-f7819ea69b55 // indirect
storj.io/eventkit v0.0.0-20250410172343-61f26d3de156 // indirect
storj.io/infectious v0.0.2 // indirect
storj.io/picobuf v0.0.4 // indirect
storj.io/uplink v1.13.1 // indirect
)
require (
cloud.google.com/go/auth v0.17.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
cloud.google.com/go/compute/metadata v0.7.0 // indirect
github.com/geoffgarside/ber v1.1.0 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
github.com/geoffgarside/ber v1.2.0 // indirect
github.com/google/s2a-go v0.1.9 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
github.com/googleapis/gax-go/v2 v2.14.2 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
github.com/hirochachacha/go-smb2 v1.1.0
google.golang.org/genproto/googleapis/api v0.0.0-20250528174236-200df99c418a // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
google.golang.org/grpc v1.73.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
google.golang.org/grpc v1.76.0 // indirect
)
require (
@@ -47,11 +211,11 @@ require (
github.com/bytedance/sonic v1.13.2 // indirect
github.com/bytedance/sonic/loader v0.2.4 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/ebitengine/purego v0.9.1 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
github.com/gabriel-vasile/mimetype v1.4.11 // indirect
github.com/gin-contrib/sse v1.1.0 // indirect
github.com/go-ini/ini v1.67.0 // indirect
github.com/go-logr/logr v1.4.3 // indirect
@@ -63,7 +227,7 @@ require (
github.com/go-openapi/swag v0.19.15 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.26.0 // indirect
github.com/go-playground/validator/v10 v10.28.0 // indirect
github.com/go-sql-driver/mysql v1.9.2 // indirect
github.com/goccy/go-json v0.10.5 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
@@ -73,40 +237,39 @@ require (
github.com/jinzhu/now v1.1.5 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.18.0 // indirect
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/klauspost/compress v1.18.1 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mailru/easyjson v0.7.6 // indirect
github.com/mailru/easyjson v0.9.1 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/minio/crc64nvme v1.0.1 // indirect
github.com/minio/crc64nvme v1.1.1 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/rogpeppe/go-internal v1.14.1 // indirect
github.com/rs/xid v1.6.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tinylib/msgp v1.3.0 // indirect
github.com/tinylib/msgp v1.5.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
go.opentelemetry.io/otel v1.36.0 // indirect
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
go.opentelemetry.io/otel v1.38.0 // indirect
go.opentelemetry.io/otel/metric v1.38.0 // indirect
go.opentelemetry.io/otel/trace v1.38.0 // indirect
golang.org/x/arch v0.17.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/tools v0.33.0 // indirect
google.golang.org/api v0.239.0
google.golang.org/protobuf v1.36.6 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/oauth2 v0.33.0
golang.org/x/sync v0.19.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
golang.org/x/tools v0.39.0 // indirect
google.golang.org/api v0.255.0
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect

File diff suppressed because it is too large Load Diff

View File

@@ -26,13 +26,15 @@ type EnvVariables struct {
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
DataFolder string
TempFolder string
DataFolder string
TempFolder string
SecretKeyPath string
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
TestPostgres12Port string `env:"TEST_POSTGRES_12_PORT"`
TestPostgres13Port string `env:"TEST_POSTGRES_13_PORT"`
TestPostgres14Port string `env:"TEST_POSTGRES_14_PORT"`
TestPostgres15Port string `env:"TEST_POSTGRES_15_PORT"`
@@ -43,7 +45,11 @@ type EnvVariables struct {
TestMinioPort string `env:"TEST_MINIO_PORT"`
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
TestNASPort string `env:"TEST_NAS_PORT"`
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
TestNASPort string `env:"TEST_NAS_PORT"`
TestFTPPort string `env:"TEST_FTP_PORT"`
TestSFTPPort string `env:"TEST_SFTP_PORT"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
@@ -54,6 +60,13 @@ type EnvVariables struct {
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
// testing Supabase
TestSupabaseHost string `env:"TEST_SUPABASE_HOST"`
TestSupabasePort string `env:"TEST_SUPABASE_PORT"`
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
}
var (
@@ -143,8 +156,13 @@ func loadEnvVariables() {
// (projectRoot/postgresus-data -> /postgresus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "backups")
env.TempFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "temp")
env.SecretKeyPath = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "secret.key")
if env.IsTesting {
if env.TestPostgres12Port == "" {
log.Error("TEST_POSTGRES_12_PORT is empty")
os.Exit(1)
}
if env.TestPostgres13Port == "" {
log.Error("TEST_POSTGRES_13_PORT is empty")
os.Exit(1)
@@ -179,6 +197,11 @@ func loadEnvVariables() {
os.Exit(1)
}
if env.TestAzuriteBlobPort == "" {
log.Error("TEST_AZURITE_BLOB_PORT is empty")
os.Exit(1)
}
if env.TestNASPort == "" {
log.Error("TEST_NAS_PORT is empty")
os.Exit(1)

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/config"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/period"
"time"
)
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
continue
}
err = storage.DeleteFile(backup.ID)
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}

View File

@@ -44,11 +44,8 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// add old backup
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -105,11 +102,8 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -169,11 +163,8 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -234,11 +225,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -262,7 +250,7 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
@@ -300,11 +288,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,

View File

@@ -2,20 +2,21 @@ package backups
import (
"context"
"errors"
"sync"
"github.com/google/uuid"
)
type BackupContextManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
cancelledBackups map[uuid.UUID]bool
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
cancelledBackups: make(map[uuid.UUID]bool),
}
}
@@ -23,25 +24,37 @@ func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc con
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
delete(m.cancelledBackups, backupID)
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if !exists {
return errors.New("backup is not in progress or already completed")
if m.cancelledBackups[backupID] {
return nil
}
cancelFunc()
delete(m.cancelFuncs, backupID)
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
}
m.cancelledBackups[backupID] = true
return nil
}
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
m.mu.RLock()
defer m.mu.RUnlock()
return m.cancelledBackups[backupID]
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
delete(m.cancelledBackups, backupID)
}

View File

@@ -1,6 +1,7 @@
package backups
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -26,6 +27,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -512,9 +514,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
Database: database,
StorageID: storage.ID,
Storage: storage,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
@@ -526,7 +526,7 @@ func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextMgr.RegisterBackup(backup.ID, func() {})
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
@@ -686,9 +686,7 @@ func createTestBackup(
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
Database: database,
StorageID: storages[0].ID,
Storage: storages[0],
Status: BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
@@ -704,7 +702,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -1,15 +1,18 @@
package backups
import (
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/usecases"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"time"
)
var backupRepository = &BackupRepository{}
@@ -23,6 +26,8 @@ var backupService = &BackupService{
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},

View File

@@ -1,5 +1,10 @@
package backups
import (
"io"
"postgresus-backend/internal/features/backups/backups/encryption"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
@@ -12,3 +17,12 @@ type GetBackupsResponse struct {
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type decryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
}

View File

@@ -0,0 +1,156 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type DecryptionReader struct {
baseReader io.Reader
cipher cipher.AEAD
buffer []byte
nonce []byte
chunkIndex uint64
headerRead bool
eof bool
}
func NewDecryptionReader(
baseReader io.Reader,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*DecryptionReader, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
reader := &DecryptionReader{
baseReader,
aesgcm,
make([]byte, 0),
nonce,
0,
false,
false,
}
if err := reader.readAndValidateHeader(salt, nonce); err != nil {
return nil, err
}
return reader, nil
}
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
for len(r.buffer) < len(p) && !r.eof {
if err := r.readAndDecryptChunk(); err != nil {
if err == io.EOF {
r.eof = true
break
}
return 0, err
}
}
if len(r.buffer) == 0 {
return 0, io.EOF
}
n = copy(p, r.buffer)
r.buffer = r.buffer[n:]
return n, nil
}
func (r *DecryptionReader) readAndValidateHeader(expectedSalt, expectedNonce []byte) error {
header := make([]byte, HeaderLen)
if _, err := io.ReadFull(r.baseReader, header); err != nil {
return fmt.Errorf("failed to read header: %w", err)
}
magic := string(header[0:MagicBytesLen])
if magic != MagicBytes {
return fmt.Errorf("invalid magic bytes: expected %s, got %s", MagicBytes, magic)
}
salt := header[MagicBytesLen : MagicBytesLen+SaltLen]
nonce := header[MagicBytesLen+SaltLen : MagicBytesLen+SaltLen+NonceLen]
if string(salt) != string(expectedSalt) {
return fmt.Errorf("salt mismatch in file header")
}
if string(nonce) != string(expectedNonce) {
return fmt.Errorf("nonce mismatch in file header")
}
r.headerRead = true
return nil
}
func (r *DecryptionReader) readAndDecryptChunk() error {
lengthBuf := make([]byte, 4)
if _, err := io.ReadFull(r.baseReader, lengthBuf); err != nil {
return err
}
chunkLen := binary.BigEndian.Uint32(lengthBuf)
if chunkLen == 0 || chunkLen > ChunkSize+16 {
return fmt.Errorf("invalid chunk length: %d", chunkLen)
}
encrypted := make([]byte, chunkLen)
if _, err := io.ReadFull(r.baseReader, encrypted); err != nil {
return fmt.Errorf("failed to read encrypted chunk: %w", err)
}
chunkNonce := r.generateChunkNonce()
decrypted, err := r.cipher.Open(nil, chunkNonce, encrypted, nil)
if err != nil {
return fmt.Errorf(
"failed to decrypt chunk (authentication failed - file may be corrupted or tampered): %w",
err,
)
}
r.buffer = append(r.buffer, decrypted...)
r.chunkIndex++
return nil
}
func (r *DecryptionReader) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, r.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], r.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,147 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type EncryptionWriter struct {
baseWriter io.Writer
cipher cipher.AEAD
buffer []byte
nonce []byte
salt []byte
chunkIndex uint64
headerWritten bool
}
func NewEncryptionWriter(
baseWriter io.Writer,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*EncryptionWriter, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
writer := &EncryptionWriter{
baseWriter: baseWriter,
cipher: aesgcm,
buffer: make([]byte, 0, ChunkSize),
nonce: nonce,
chunkIndex: 0,
headerWritten: false,
salt: salt, // Store salt for lazy header writing
}
return writer, nil
}
func (w *EncryptionWriter) Write(p []byte) (n int, err error) {
// Write header on first write (lazy initialization)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return 0, fmt.Errorf("failed to write header: %w", err)
}
}
n = len(p)
w.buffer = append(w.buffer, p...)
for len(w.buffer) >= ChunkSize {
chunk := w.buffer[:ChunkSize]
if err := w.encryptAndWriteChunk(chunk); err != nil {
return 0, err
}
w.buffer = w.buffer[ChunkSize:]
}
return n, nil
}
func (w *EncryptionWriter) Close() error {
// Write header if it hasn't been written yet (in case Close is called without any writes)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
}
if len(w.buffer) > 0 {
if err := w.encryptAndWriteChunk(w.buffer); err != nil {
return err
}
w.buffer = nil
}
return nil
}
func (w *EncryptionWriter) writeHeader(salt, nonce []byte) error {
header := make([]byte, HeaderLen)
copy(header[0:MagicBytesLen], []byte(MagicBytes))
copy(header[MagicBytesLen:MagicBytesLen+SaltLen], salt)
copy(header[MagicBytesLen+SaltLen:MagicBytesLen+SaltLen+NonceLen], nonce)
_, err := w.baseWriter.Write(header)
if err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
w.headerWritten = true
return nil
}
func (w *EncryptionWriter) encryptAndWriteChunk(chunk []byte) error {
chunkNonce := w.generateChunkNonce()
encrypted := w.cipher.Seal(nil, chunkNonce, chunk, nil)
lengthBuf := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBuf, uint32(len(encrypted)))
if _, err := w.baseWriter.Write(lengthBuf); err != nil {
return fmt.Errorf("failed to write chunk length: %w", err)
}
if _, err := w.baseWriter.Write(encrypted); err != nil {
return fmt.Errorf("failed to write encrypted chunk: %w", err)
}
w.chunkIndex++
return nil
}
func (w *EncryptionWriter) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, w.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], w.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,387 @@
package encryption
import (
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_EncryptDecryptRoundTrip_ReturnsOriginalData(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte(
"This is a test backup data that should be encrypted and then decrypted successfully.",
)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted := make([]byte, len(originalData))
n, err = io.ReadFull(reader, decrypted)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptDecryptRoundTrip_LargeData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := make([]byte, 100*1024)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptionWriter_MultipleWrites_CombinesCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
part1 := []byte("First part of data. ")
part2 := []byte("Second part of data. ")
part3 := []byte("Third part of data.")
expectedData := append(append(part1, part2...), part3...)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(part1)
require.NoError(t, err)
_, err = writer.Write(part2)
require.NoError(t, err)
_, err = writer.Write(part3)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, expectedData, decrypted)
}
func Test_DecryptionReader_InvalidHeader_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
invalidHeader := make([]byte, HeaderLen)
copy(invalidHeader, []byte("INVALID!"))
invalidData := bytes.NewBuffer(invalidHeader)
_, err = NewDecryptionReader(invalidData, masterKey, backupID, salt, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid magic bytes")
}
func Test_DecryptionReader_TamperedData_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This data will be tampered with.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
encryptedBytes := encrypted.Bytes()
if len(encryptedBytes) > HeaderLen+10 {
encryptedBytes[HeaderLen+10] ^= 0xFF
}
tamperedBuffer := bytes.NewBuffer(encryptedBytes)
reader, err := NewDecryptionReader(tamperedBuffer, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_DeriveBackupKey_SameInputs_ReturnsSameKey(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
assert.Equal(t, key1, key2)
}
func Test_DeriveBackupKey_DifferentInputs_ReturnsDifferentKeys(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID1 := uuid.New()
backupID2 := uuid.New()
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey1, backupID1, salt1)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey2, backupID1, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key2)
key3, err := DeriveBackupKey(masterKey1, backupID2, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key3)
key4, err := DeriveBackupKey(masterKey1, backupID1, salt2)
require.NoError(t, err)
assert.NotEqual(t, key1, key4)
}
func Test_EncryptionWriter_PartialChunk_HandledCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
smallData := []byte("Small data less than chunk size")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(smallData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, smallData, decrypted)
}
func Test_GenerateSalt_ReturnsCorrectLength(t *testing.T) {
salt, err := GenerateSalt()
require.NoError(t, err)
assert.Equal(t, SaltLen, len(salt))
}
func Test_GenerateSalt_GeneratesUniqueSalts(t *testing.T) {
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
assert.NotEqual(t, salt1, salt2)
}
func Test_GenerateNonce_ReturnsCorrectLength(t *testing.T) {
nonce, err := GenerateNonce()
require.NoError(t, err)
assert.Equal(t, NonceLen, len(nonce))
}
func Test_GenerateNonce_GeneratesUniqueNonces(t *testing.T) {
nonce1, err := GenerateNonce()
require.NoError(t, err)
nonce2, err := GenerateNonce()
require.NoError(t, err)
assert.NotEqual(t, nonce1, nonce2)
}
func Test_DecryptionReader_WrongMasterKey_ReturnsError(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("Secret data")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey1, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey2, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_EncryptionWriter_EmptyData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, 0, len(decrypted))
}
func Test_EncryptionWriter_MultipleChunks_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
dataSize := ChunkSize*3 + 1000
originalData := make([]byte, dataSize)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_DecryptionReader_SmallReads_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This is test data that will be read in small chunks.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
var decrypted []byte
buf := make([]byte, 5)
for {
n, err := reader.Read(buf)
if n > 0 {
decrypted = append(decrypted, buf[:n]...)
}
if err == io.EOF {
break
}
require.NoError(t, err)
}
assert.Equal(t, originalData, decrypted)
}

View File

@@ -0,0 +1,52 @@
package encryption
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"github.com/google/uuid"
"golang.org/x/crypto/pbkdf2"
)
const (
MagicBytes = "PGRSUS01"
MagicBytesLen = 8
SaltLen = 32
NonceLen = 12
ReservedLen = 12
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
ChunkSize = 1 * 1024 * 1024
PBKDF2Iterations = 100000
)
func DeriveBackupKey(masterKey string, backupID uuid.UUID, salt []byte) ([]byte, error) {
if masterKey == "" {
return nil, fmt.Errorf("master key cannot be empty")
}
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes", SaltLen)
}
keyMaterial := []byte(masterKey + backupID.String())
derivedKey := pbkdf2.Key(keyMaterial, salt, PBKDF2Iterations, 32, sha256.New)
return derivedKey, nil
}
func GenerateSalt() ([]byte, error) {
salt := make([]byte, SaltLen)
if _, err := rand.Read(salt); err != nil {
return nil, fmt.Errorf("failed to generate salt: %w", err)
}
return salt, nil
}
func GenerateNonce() ([]byte, error) {
nonce := make([]byte, NonceLen)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
return nonce, nil
}

View File

@@ -3,6 +3,7 @@ package backups
import (
"context"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
@@ -29,7 +30,7 @@ type CreateBackupUsecase interface {
backupProgressListener func(
completedMBs float64,
),
) error
) (*usecases_postgresql.BackupMetadata, error)
}
type BackupRemoveListener interface {

View File

@@ -1,8 +1,7 @@
package backups
import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/storages"
backups_config "postgresus-backend/internal/features/backups/config"
"time"
"github.com/google/uuid"
@@ -11,11 +10,8 @@ import (
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Database *databases.Database `json:"database" gorm:"foreignKey:DatabaseID"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
@@ -24,5 +20,9 @@ type Backup struct {
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}

View File

@@ -13,18 +13,20 @@ import (
type BackupRepository struct{}
func (r *BackupRepository) Save(backup *Backup) error {
if backup.DatabaseID == uuid.Nil || backup.StorageID == uuid.Nil {
return errors.New("database ID and storage ID are required")
}
db := storage.GetDb()
isNew := backup.ID == uuid.Nil
if isNew {
backup.ID = uuid.New()
return db.Create(backup).
Omit("Database", "Storage").
Error
}
return db.Save(backup).
Omit("Database", "Storage").
Error
}
@@ -33,8 +35,6 @@ func (r *BackupRepository) FindByDatabaseID(databaseID uuid.UUID) ([]*Backup, er
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -56,8 +56,6 @@ func (r *BackupRepository) FindByDatabaseIDWithLimit(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).
@@ -73,8 +71,6 @@ func (r *BackupRepository) FindByStorageID(storageID uuid.UUID) ([]*Backup, erro
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ?", storageID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -89,8 +85,6 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
First(&backup).Error; err != nil {
@@ -109,8 +103,6 @@ func (r *BackupRepository) FindByID(id uuid.UUID) (*Backup, error) {
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("id = ?", id).
First(&backup).Error; err != nil {
return nil, err
@@ -124,8 +116,6 @@ func (r *BackupRepository) FindByStatus(status BackupStatus) ([]*Backup, error)
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("status = ?", status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -143,8 +133,6 @@ func (r *BackupRepository) FindByStorageIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ? AND status = ?", storageID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -162,8 +150,6 @@ func (r *BackupRepository) FindByDatabaseIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND status = ?", databaseID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -185,8 +171,6 @@ func (r *BackupRepository) FindBackupsBeforeDate(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND created_at < ?", databaseID, date).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -204,8 +188,6 @@ func (r *BackupRepository) FindByDatabaseIDWithPagination(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).

View File

@@ -2,20 +2,25 @@ package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"slices"
"strings"
"time"
util_encryption "postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -27,6 +32,8 @@ type BackupService struct {
notifierService *notifiers.NotifierService
notificationSender NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
@@ -34,9 +41,9 @@ type BackupService struct {
backupRemoveListeners []BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextMgr *BackupContextManager
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
@@ -149,11 +156,16 @@ func (s *BackupService) DeleteBackup(
return err
}
if backup.Database.WorkspaceID == nil {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot delete backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*backup.Database.WorkspaceID, user)
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
@@ -168,11 +180,11 @@ func (s *BackupService) DeleteBackup(
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
backup.Database.Name,
database.Name,
backupID.String(),
),
&user.ID,
backup.Database.WorkspaceID,
database.WorkspaceID,
)
return s.deleteBackup(backup)
@@ -220,10 +232,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup := &Backup{
DatabaseID: databaseID,
Database: database,
StorageID: storage.ID,
Storage: storage,
StorageID: storage.ID,
Status: BackupStatusInProgress,
@@ -251,10 +260,10 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextMgr.RegisterBackup(backup.ID, cancel)
defer s.backupContextMgr.UnregisterBackup(backup.ID)
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
err = s.createBackupUseCase.Execute(
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
@@ -266,7 +275,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
errors.Is(err, context.Canceled)
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
@@ -278,7 +292,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
@@ -324,6 +338,13 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
@@ -364,6 +385,11 @@ func (s *BackupService) SendBackupNotification(
return
}
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
@@ -375,9 +401,17 @@ func (s *BackupService) SendBackupNotification(
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf("❌ Backup failed for database \"%s\"", database.Name)
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf("✅ Backup completed for database \"%s\"", database.Name)
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
@@ -427,11 +461,16 @@ func (s *BackupService) CancelBackup(
return err
}
if backup.Database.WorkspaceID == nil {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*backup.Database.WorkspaceID, user)
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
@@ -443,18 +482,18 @@ func (s *BackupService) CancelBackup(
return errors.New("backup is not in progress")
}
if err := s.backupContextMgr.CancelBackup(backupID); err != nil {
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
backup.Database.Name,
database.Name,
backupID.String(),
),
&user.ID,
backup.Database.WorkspaceID,
database.WorkspaceID,
)
return nil
@@ -469,12 +508,17 @@ func (s *BackupService) GetBackupFile(
return nil, err
}
if backup.Database.WorkspaceID == nil {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*backup.Database.WorkspaceID,
*database.WorkspaceID,
user,
)
if err != nil {
@@ -484,22 +528,17 @@ func (s *BackupService) GetBackupFile(
return nil, errors.New("insufficient permissions to download backup for this database")
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return nil, err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
backup.Database.Name,
database.Name,
backupID.String(),
),
&user.ID,
backup.Database.WorkspaceID,
database.WorkspaceID,
)
return storage.GetFile(backup.ID)
return s.getBackupReader(backupID)
}
func (s *BackupService) deleteBackup(backup *Backup) error {
@@ -514,9 +553,12 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
return err
}
err = storage.DeleteFile(backup.ID)
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
return err
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
@@ -551,3 +593,91 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
return nil
}
// GetBackupReader returns a reader for the backup file
// If encrypted, wraps with DecryptionReader
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("failed to find backup: %w", err)
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
// If not encrypted, return raw reader
if backup.Encryption == backups_config.BackupEncryptionNone {
s.logger.Info("Returning non-encrypted backup", "backupId", backupID)
return fileReader, nil
}
// Decrypt on-the-fly for encrypted backups
if backup.Encryption != backups_config.BackupEncryptionEncrypted {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("unsupported encryption type: %s", backup.Encryption)
}
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("backup marked as encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to get master key: %w", err)
}
// Decode salt and IV
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode IV: %w", err)
}
// Wrap with decrypting reader
decryptionReader, err := encryption.NewDecryptionReader(
fileReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to create decrypting reader: %w", err)
}
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
}, nil
}

View File

@@ -3,17 +3,22 @@ package backups
import (
"context"
"errors"
"strings"
"testing"
"time"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -52,11 +57,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
@@ -98,11 +105,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
@@ -121,11 +130,13 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
@@ -170,9 +181,9 @@ func (uc *CreateFailedBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return errors.New("backup failed")
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct {
@@ -187,7 +198,11 @@ func (uc *CreateSuccessBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return nil
return &usecases_postgresql.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}

View File

@@ -15,7 +15,7 @@ type CreateBackupUsecase struct {
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
}
// Execute creates a backup of the database and returns the backup size in MB
// Execute creates a backup of the database and returns the backup metadata
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
@@ -25,7 +25,7 @@ func (uc *CreateBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
if database.Type == databases.DatabaseTypePostgres {
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
@@ -37,5 +37,5 @@ func (uc *CreateBackupUsecase) Execute(
)
}
return errors.New("database type not supported")
return nil, errors.New("database type not supported")
}

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -14,17 +15,39 @@ import (
"time"
"postgresus-backend/internal/config"
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"github.com/google/uuid"
)
const (
backupTimeout = 23 * time.Hour
shutdownCheckInterval = 1 * time.Second
copyBufferSize = 8 * 1024 * 1024
progressReportIntervalMB = 1.0
pgConnectTimeout = 30
compressionLevel = 5
exitCodeAccessViolation = -1073741819
exitCodeGenericError = 1
exitCodeConnectionError = 2
)
type CreatePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor encryption.FieldEncryptor
}
type writeResult struct {
bytesWritten int
writeErr error
}
// Execute creates a backup of the database
@@ -37,7 +60,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*BackupMetadata, error) {
uc.logger.Info(
"Creating PostgreSQL backup via pg_dump custom format",
"databaseId",
@@ -47,38 +70,24 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
)
if !backupConfig.IsBackupsEnabled {
return fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
pg := db.Postgresql
if pg == nil {
return fmt.Errorf("postgresql database configuration is required for pg_dump backups")
return nil, fmt.Errorf("postgresql database configuration is required for pg_dump backups")
}
if pg.Database == nil || *pg.Database == "" {
return fmt.Errorf("database name is required for pg_dump backups")
return nil, fmt.Errorf("database name is required for pg_dump backups")
}
args := []string{
"-Fc", // custom format with built-in compression
"--no-password", // Use environment variable for password, prevent prompts
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose", // Add verbose output to help with debugging
}
args := uc.buildPgDumpArgs(pg)
// Use zstd compression level 5 for PostgreSQL 15+ (better compression and speed)
// Fall back to gzip compression level 5 for older versions
if pg.Version == tools.PostgresqlVersion13 || pg.Version == tools.PostgresqlVersion14 ||
pg.Version == tools.PostgresqlVersion15 {
args = append(args, "-Z", "5")
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
} else {
args = append(args, "--compress=zstd:5")
uc.logger.Info("Using zstd compression level 5", "version", pg.Version)
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, pg.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
}
return uc.streamToStorage(
@@ -92,7 +101,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
config.GetEnv().PostgresesInstallDir,
),
args,
pg.Password,
decryptedPassword,
storage,
db,
backupProgressListener,
@@ -110,124 +119,38 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storage *storages.Storage,
db *databases.Database,
backupProgressListener func(completedMBs float64),
) error {
) (*BackupMetadata, error) {
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
// if backup not fit into 23 hours, Postgresus
// seems not to work for such database size
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
ctx, cancel := uc.createBackupContext(parentCtx)
defer cancel()
// Monitor for shutdown and cancel context if needed
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
// Create temporary .pgpass file as a more reliable alternative to PGPASSWORD
pgpassFile, err := uc.createTempPgpassFile(db.Postgresql, password)
pgpassFile, err := uc.setupPgpassFile(db.Postgresql, password)
if err != nil {
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
return nil, err
}
defer func() {
if pgpassFile != "" {
_ = os.Remove(pgpassFile)
// Remove the entire temp directory (which contains the .pgpass file)
_ = os.RemoveAll(filepath.Dir(pgpassFile))
}
}()
// Verify .pgpass file was created successfully
if pgpassFile == "" {
return fmt.Errorf("temporary .pgpass file was not created")
}
// Verify .pgpass file was created correctly
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return fmt.Errorf("failed to verify .pgpass file: %w", err)
}
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
// Start with system environment variables to preserve Windows PATH, SystemRoot, etc.
cmd.Env = os.Environ()
// Use the .pgpass file for authentication
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
// Debug password setup (without exposing the actual password)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", backupConfig.CpuCount,
)
// Add PostgreSQL-specific environment variables
cmd.Env = append(cmd.Env, "PGCLIENTENCODING=UTF8")
cmd.Env = append(cmd.Env, "PGCONNECT_TIMEOUT=30")
// Add encoding-related environment variables to handle character encoding issues
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
// Add PostgreSQL-specific encoding settings
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
shouldRequireSSL := db.Postgresql.IsHttps
// Require SSL when explicitly configured
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", db.Postgresql.IsHttps)
} else {
// SSL not explicitly required, but prefer it if available
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", db.Postgresql.IsHttps)
}
// Set other SSL parameters to avoid certificate issues
cmd.Env = append(cmd.Env, "PGSSLCERT=") // No client certificate
cmd.Env = append(cmd.Env, "PGSSLKEY=") // No client key
cmd.Env = append(cmd.Env, "PGSSLROOTCERT=") // No root certificate verification
cmd.Env = append(cmd.Env, "PGSSLCRL=") // No certificate revocation list
// Verify executable exists and is accessible
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf(
"PostgreSQL executable not found or not accessible: %s - %w",
pgBin,
err,
)
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, backupConfig.CpuCount, pgBin); err != nil {
return nil, err
}
pgStdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("stdout pipe: %w", err)
return nil, fmt.Errorf("stdout pipe: %w", err)
}
pgStderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("stderr pipe: %w", err)
return nil, fmt.Errorf("stderr pipe: %w", err)
}
// Capture stderr in a separate goroutine to ensure we don't miss any error output
@@ -237,23 +160,31 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
stderrCh <- stderrOutput
}()
// A pipe connecting pg_dump output → storage
storageReader, storageWriter := io.Pipe()
// Create a counting writer to track bytes
countingWriter := &CountingWriter{writer: storageWriter}
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backupConfig,
storageWriter,
)
if err != nil {
return nil, err
}
countingWriter := &CountingWriter{writer: finalWriter}
// The backup ID becomes the object key / filename in storage
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErrCh <- storage.SaveFile(uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErrCh <- saveErr
}()
// Start pg_dump
if err = cmd.Start(); err != nil {
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
return nil, fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
}
// Copy pg output directly to storage with shutdown checks
@@ -270,37 +201,22 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
copyResultCh <- err
}()
// Wait for the copy to finish first, then the dump process
copyErr := <-copyResultCh
bytesWritten := <-bytesWrittenCh
waitErr := cmd.Wait()
// Check for shutdown or cancellation before finalizing
select {
case <-ctx.Done():
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
}
<-saveErrCh // Wait for storage to finish
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
return nil, uc.checkCancellationReason()
default:
}
// Close the pipe writer to signal end of data
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
<-saveErrCh
return nil, err
}
// Wait until storage ends reading
saveErr := <-saveErrCh
stderrOutput := <-stderrCh
@@ -312,149 +228,34 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
switch {
case waitErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
// Enhanced error handling for PostgreSQL connection and SSL issues
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf(
"%s failed: %v stderr: %s",
filepath.Base(pgBin),
waitErr,
stderrStr,
)
// Check for specific PostgreSQL error patterns
if exitErr, ok := waitErr.(*exec.ExitError); ok {
exitCode := exitErr.ExitCode()
// Enhanced debugging for exit status 1 with empty stderr
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=30",
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
},
)
errorMsg = fmt.Sprintf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"args", args,
"exitCode", fmt.Sprintf("0x%X", uint32(exitCode)),
)
errorMsg = fmt.Sprintf(
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
} else if exitCode == 1 || exitCode == 2 {
// Check for common connection and authentication issues
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). The server may not allow connections from your IP address or may require different authentication settings. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. stderr: %s",
len(password),
password == "",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
errorMsg = fmt.Sprintf(
"PostgreSQL SSL connection failed. The server may require SSL encryption or have SSL configuration issues. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection refused. Check if the server is running and accessible from your network. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "timeout") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
}
}
return errors.New(errorMsg)
return nil, uc.buildPgDumpErrorMessage(waitErr, stderrOutput, pgBin, args, password)
case copyErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("copy to storage: %w", copyErr)
return nil, fmt.Errorf("copy to storage: %w", copyErr)
case saveErr != nil:
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("save to storage: %w", saveErr)
return nil, fmt.Errorf("save to storage: %w", saveErr)
}
return nil
return &backupMetadata, nil
}
// copyWithShutdownCheck copies data from src to dst while checking for shutdown
func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
ctx context.Context,
dst io.Writer,
src io.Reader,
backupProgressListener func(completedMBs float64),
) (int64, error) {
buf := make([]byte, 32*1024) // 32KB buffer
buf := make([]byte, copyBufferSize)
var totalBytesWritten int64
// Progress reporting interval - report every 1MB of data
var lastReportedMB float64
const reportIntervalMB = 1.0
for {
select {
@@ -469,7 +270,23 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
bytesRead, readErr := src.Read(buf)
if bytesRead > 0 {
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
writeResultCh := make(chan writeResult, 1)
go func() {
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
writeResultCh <- writeResult{bytesWritten, writeErr}
}()
var bytesWritten int
var writeErr error
select {
case <-ctx.Done():
return totalBytesWritten, fmt.Errorf("copy cancelled during write: %w", ctx.Err())
case result := <-writeResultCh:
bytesWritten = result.bytesWritten
writeErr = result.writeErr
}
if bytesWritten < 0 || bytesRead < bytesWritten {
bytesWritten = 0
if writeErr == nil {
@@ -487,12 +304,9 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
totalBytesWritten += int64(bytesWritten)
// Report progress based on total size
if backupProgressListener != nil {
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
// Only report if we've written at least 1MB more data than last report
if currentSizeMB >= lastReportedMB+reportIntervalMB {
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
backupProgressListener(currentSizeMB)
lastReportedMB = currentSizeMB
}
@@ -503,7 +317,6 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
if readErr != io.EOF {
return totalBytesWritten, readErr
}
break
}
}
@@ -511,12 +324,417 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
return totalBytesWritten, nil
}
// containsIgnoreCase checks if a string contains a substring, ignoring case
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlDatabase) []string {
args := []string{
"-Fc",
"--no-password",
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose",
}
for _, schema := range pg.IncludeSchemas {
args = append(args, "-n", schema)
}
compressionArgs := uc.getCompressionArgs(pg.Version)
return append(args, compressionArgs...)
}
func (uc *CreatePostgresqlBackupUsecase) getCompressionArgs(
version tools.PostgresqlVersion,
) []string {
if uc.isOlderPostgresVersion(version) {
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", version)
return []string{"-Z", strconv.Itoa(compressionLevel)}
}
uc.logger.Info("Using zstd compression level 5", "version", version)
return []string{fmt.Sprintf("--compress=zstd:%d", compressionLevel)}
}
func (uc *CreatePostgresqlBackupUsecase) isOlderPostgresVersion(
version tools.PostgresqlVersion,
) bool {
return version == tools.PostgresqlVersion12 ||
version == tools.PostgresqlVersion13 ||
version == tools.PostgresqlVersion14 ||
version == tools.PostgresqlVersion15
}
func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
parentCtx context.Context,
) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
go func() {
ticker := time.NewTicker(shutdownCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-parentCtx.Done():
cancel()
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
return ctx, cancel
}
func (uc *CreatePostgresqlBackupUsecase) setupPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
) (string, error) {
pgpassFile, err := uc.createTempPgpassFile(pgConfig, password)
if err != nil {
return "", fmt.Errorf("failed to create temporary .pgpass file: %w", err)
}
if pgpassFile == "" {
return "", fmt.Errorf("temporary .pgpass file was not created")
}
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return "", fmt.Errorf("failed to verify .pgpass file: %w", err)
}
return pgpassFile, nil
}
func (uc *CreatePostgresqlBackupUsecase) setupPgEnvironment(
cmd *exec.Cmd,
pgpassFile string,
shouldRequireSSL bool,
password string,
cpuCount int,
pgBin string,
) error {
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", cpuCount,
)
cmd.Env = append(cmd.Env,
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT="+strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
)
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", shouldRequireSSL)
} else {
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", shouldRequireSSL)
}
cmd.Env = append(cmd.Env,
"PGSSLCERT=",
"PGSSLKEY=",
"PGSSLROOTCERT=",
"PGSSLCRL=",
)
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf("PostgreSQL executable not found or not accessible: %s - %w", pgBin, err)
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
metadata := BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
saveErrCh chan error,
) {
if encryptionWriter != nil {
go func() {
if closeErr := encryptionWriter.Close(); closeErr != nil {
uc.logger.Error(
"Failed to close encrypting writer during cancellation",
"error",
closeErr,
)
}
}()
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
}
<-saveErrCh
}
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
) error {
encryptionCloseErrCh := make(chan error, 1)
if encryptionWriter != nil {
go func() {
closeErr := encryptionWriter.Close()
if closeErr != nil {
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
}
encryptionCloseErrCh <- closeErr
}()
} else {
encryptionCloseErrCh <- nil
}
encryptionCloseErr := <-encryptionCloseErrCh
if encryptionCloseErr != nil {
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
}
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer", "error", err)
return err
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellation(ctx context.Context) error {
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
return nil
}
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellationReason() error {
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
}
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
password string,
) error {
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf("%s failed: %v stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
return errors.New(errorMsg)
}
exitCode := exitErr.ExitCode()
if exitCode == exitCodeGenericError && strings.TrimSpace(stderrStr) == "" {
return uc.handleExitCode1NoStderr(pgBin, args)
}
if exitCode == exitCodeAccessViolation {
return uc.handleAccessViolation(pgBin, stderrStr)
}
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
return uc.handleConnectionErrors(stderrStr, password)
}
return errors.New(errorMsg)
}
func (uc *CreatePostgresqlBackupUsecase) handleExitCode1NoStderr(
pgBin string,
args []string,
) error {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=" + strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
},
)
return fmt.Errorf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
}
func (uc *CreatePostgresqlBackupUsecase) handleAccessViolation(
pgBin string,
stderrStr string,
) error {
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"exitCode", "0xC0000005",
)
return fmt.Errorf(
"%s crashed with access violation (0xC0000005). "+
"This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. "+
"stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
}
func (uc *CreatePostgresqlBackupUsecase) handleConnectionErrors(
stderrStr string,
password string,
) error {
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
return fmt.Errorf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). "+
"The server may not allow connections from your IP address or may require different authentication settings. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "no password supplied") ||
containsIgnoreCase(stderrStr, "fe_sendauth") {
return fmt.Errorf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. "+
"stderr: %s",
len(password),
password == "",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
return fmt.Errorf(
"PostgreSQL SSL connection failed. "+
"The server may require SSL encryption or have SSL configuration issues. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
return fmt.Errorf(
"PostgreSQL connection refused. "+
"Check if the server is running and accessible from your network. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "authentication") ||
containsIgnoreCase(stderrStr, "password") {
return fmt.Errorf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "timeout") {
return fmt.Errorf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
return fmt.Errorf("PostgreSQL connection or authentication error. stderr: %s", stderrStr)
}
// createTempPgpassFile creates a temporary .pgpass file with the given password
func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
@@ -525,14 +743,17 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
return "", nil
}
escapedHost := tools.EscapePgpassField(pgConfig.Host)
escapedUsername := tools.EscapePgpassField(pgConfig.Username)
escapedPassword := tools.EscapePgpassField(password)
pgpassContent := fmt.Sprintf("%s:%d:*:%s:%s",
pgConfig.Host,
escapedHost,
pgConfig.Port,
pgConfig.Username,
password,
escapedUsername,
escapedPassword,
)
// it always create unique directory like /tmp/pgpass-1234567890
tempDir, err := os.MkdirTemp("", "pgpass")
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
@@ -546,3 +767,7 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
return pgpassFile, nil
}
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
}

View File

@@ -1,11 +1,15 @@
package usecases_postgresql
import (
"postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
logger.GetLogger(),
secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
}
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {

View File

@@ -0,0 +1,15 @@
package usecases_postgresql
import backups_config "postgresus-backend/internal/features/backups/config"
type EncryptionMetadata struct {
Salt string
IV string
Encryption backups_config.BackupEncryption
}
type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
}

View File

@@ -20,15 +20,15 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
// SaveBackupConfig
// @Summary Save backup configuration
// @Description Save or update backup configuration for a database
// @Description Save or update backup configuration for a database. Encryption can be set to NONE (no encryption) or ENCRYPTED (AES-256-GCM encryption).
// @Tags backup-configs
// @Accept json
// @Produce json
// @Param request body BackupConfig true "Backup configuration data"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 500
// @Param request body BackupConfig true "Backup configuration data (encryption field: NONE or ENCRYPTED)"
// @Success 200 {object} BackupConfig "Returns the saved backup configuration including encryption settings"
// @Failure 400 {object} map[string]string "Invalid encryption value or other validation errors"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 500 {object} map[string]string "Internal server error"
// @Router /backup-configs/save [post]
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -57,14 +57,14 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
// GetBackupConfigByDbID
// @Summary Get backup configuration by database ID
// @Description Get backup configuration for a specific database
// @Description Get backup configuration for a specific database including encryption settings (NONE or ENCRYPTED)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 404
// @Success 200 {object} BackupConfig "Returns backup configuration with encryption field"
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Backup configuration not found"
// @Router /backup-configs/database/{id} [get]
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -94,7 +94,6 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
// @Tags backup-configs
// @Produce json
// @Param id path string true "Storage ID"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {object} map[string]bool
// @Failure 400
// @Failure 401
@@ -113,19 +112,7 @@ func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
return
}
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
isUsing, err := c.backupConfigService.IsStorageUsing(user, workspaceID, id)
isUsing, err := c.backupConfigService.IsStorageUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return

View File

@@ -341,7 +341,7 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using?workspace_id="+workspace.ID.String(),
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
@@ -354,7 +354,7 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using?workspace_id="+workspace.ID.String(),
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
@@ -368,6 +368,86 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
}
}
func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionNone, response.Encryption)
}
func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionEncrypted,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionEncrypted, response.Encryption)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,

View File

@@ -6,3 +6,10 @@ const (
NotificationBackupFailed BackupNotificationType = "BACKUP_FAILED"
NotificationBackupSuccess BackupNotificationType = "BACKUP_SUCCESS"
)
type BackupEncryption string
const (
BackupEncryptionNone BackupEncryption = "NONE"
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)

View File

@@ -31,6 +31,8 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
CpuCount int `json:"cpuCount" gorm:"type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
}
func (h *BackupConfig) TableName() string {
@@ -88,6 +90,11 @@ func (b *BackupConfig) Validate() error {
return errors.New("max failed tries count must be greater than 0")
}
if b.Encryption != "" && b.Encryption != BackupEncryptionNone &&
b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption must be NONE or ENCRYPTED")
}
return nil
}
@@ -103,5 +110,6 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
CpuCount: b.CpuCount,
Encryption: b.Encryption,
}
}

View File

@@ -119,7 +119,6 @@ func (s *BackupConfigService) GetBackupConfigByDbId(
func (s *BackupConfigService) IsStorageUsing(
user *users_models.User,
workspaceID uuid.UUID,
storageID uuid.UUID,
) (bool, error) {
_, err := s.storageService.GetStorage(user, storageID)
@@ -172,6 +171,7 @@ func (s *BackupConfigService) initializeDefaultConfig(
CpuCount: 1,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
})
return err

View File

@@ -26,7 +26,8 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/databases/test-connection-direct", c.TestDatabaseConnectionDirect)
router.POST("/databases/:id/copy", c.CopyDatabase)
router.GET("/databases/notifier/:id/is-using", c.IsNotifierUsing)
router.POST("/databases/is-readonly", c.IsUserReadOnly)
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
}
// CreateDatabase
@@ -271,7 +272,6 @@ func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
// @Tags databases
// @Produce json
// @Param id path string true "Notifier ID"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {object} map[string]bool
// @Failure 400
// @Failure 401
@@ -290,19 +290,7 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
return
}
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
isUsing, err := c.databaseService.IsNotifierUsing(user, workspaceID, id)
isUsing, err := c.databaseService.IsNotifierUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -343,3 +331,76 @@ func (c *DatabaseController) CopyDatabase(ctx *gin.Context) {
ctx.JSON(http.StatusCreated, copiedDatabase)
}
// IsUserReadOnly
// @Summary Check if database user is read-only
// @Description Check if current database credentials have only read (SELECT) privileges
// @Tags databases
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body Database true "Database configuration to check"
// @Success 200 {object} IsReadOnlyResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /databases/is-readonly [post]
func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
}
// CreateReadOnlyUser
// @Summary Create read-only database user
// @Description Create a new PostgreSQL user with read-only privileges for backup operations
// @Tags databases
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param request body Database true "Database configuration to create user for"
// @Success 200 {object} CreateReadOnlyUserResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /databases/create-readonly-user [post]
func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
username, password, err := c.databaseService.CreateReadOnlyUser(user, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, CreateReadOnlyUserResponse{
Username: username,
Password: password,
})
}

View File

@@ -16,6 +16,7 @@ import (
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -769,6 +770,71 @@ func createTestDatabaseViaAPI(
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
testDbName := "test_db"
plainPassword := "my-super-secret-password-123"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: plainPassword,
Database: &testDbName,
},
}
var createdDatabase Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
request,
http.StatusCreated,
&createdDatabase,
)
repository := &DatabaseRepository{}
databaseFromDB, err := repository.FindByID(createdDatabase.ID)
assert.NoError(t, err)
assert.NotNil(t, databaseFromDB)
assert.NotNil(t, databaseFromDB.Postgresql)
assert.True(
t,
strings.HasPrefix(databaseFromDB.Postgresql.Password, "enc:"),
"Password should be encrypted in database with 'enc:' prefix, got: %s",
databaseFromDB.Postgresql.Password,
)
encryptor := encryption.GetFieldEncryptor()
decryptedPassword, err := encryptor.Decrypt(
databaseFromDB.ID,
databaseFromDB.Postgresql.Password,
)
assert.NoError(t, err)
assert.Equal(t, plainPassword, decryptedPassword,
"Decrypted password should match original plaintext password")
test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+createdDatabase.ID.String(),
"Bearer "+owner.Token,
http.StatusNoContent,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -815,7 +881,15 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
assert.Equal(t, "original-password-secret", database.Postgresql.Password)
// Verify password is encrypted
assert.True(t, strings.HasPrefix(database.Postgresql.Password, "enc:"),
"Password should be encrypted in database")
// Verify it can be decrypted back to original
encryptor := encryption.GetFieldEncryptor()
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)

View File

@@ -5,20 +5,21 @@ import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"regexp"
"slices"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"gorm.io/gorm"
)
type PostgresqlDatabase struct {
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
DatabaseID *uuid.UUID `json:"databaseId" gorm:"type:uuid;column:database_id"`
RestoreID *uuid.UUID `json:"restoreId" gorm:"type:uuid;column:restore_id"`
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
@@ -29,17 +30,40 @@ type PostgresqlDatabase struct {
Password string `json:"password" gorm:"type:text;not null"`
Database *string `json:"database" gorm:"type:text"`
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
// backup settings
IncludeSchemas []string `json:"includeSchemas" gorm:"-"`
IncludeSchemasString string `json:"-" gorm:"column:include_schemas;type:text;not null;default:''"`
// restore settings (not saved to DB)
IsExcludeExtensions bool `json:"isExcludeExtensions" gorm:"-"`
}
func (p *PostgresqlDatabase) TableName() string {
return "postgresql_databases"
}
func (p *PostgresqlDatabase) Validate() error {
if p.Version == "" {
return errors.New("version is required")
func (p *PostgresqlDatabase) BeforeSave(_ *gorm.DB) error {
if len(p.IncludeSchemas) > 0 {
p.IncludeSchemasString = strings.Join(p.IncludeSchemas, ",")
} else {
p.IncludeSchemasString = ""
}
return nil
}
func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error {
if p.IncludeSchemasString != "" {
p.IncludeSchemas = strings.Split(p.IncludeSchemasString, ",")
} else {
p.IncludeSchemas = []string{}
}
return nil
}
func (p *PostgresqlDatabase) Validate() error {
if p.Host == "" {
return errors.New("host is required")
}
@@ -59,14 +83,22 @@ func (p *PostgresqlDatabase) Validate() error {
return nil
}
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
func (p *PostgresqlDatabase) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return testSingleDatabaseConnection(logger, ctx, p)
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
}
func (p *PostgresqlDatabase) HideSensitiveData() {
if p == nil {
return
}
p.Password = ""
}
@@ -77,25 +109,491 @@ func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
p.Username = incoming.Username
p.Database = incoming.Database
p.IsHttps = incoming.IsHttps
p.IncludeSchemas = incoming.IncludeSchemas
if incoming.Password != "" {
p.Password = incoming.Password
}
}
func (p *PostgresqlDatabase) EncryptSensitiveFields(
databaseID uuid.UUID,
encryptor encryption.FieldEncryptor,
) error {
if p.Password != "" {
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
if err != nil {
return err
}
p.Password = encrypted
}
return nil
}
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
// This should be called before encrypting sensitive fields.
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
if p.Version != "" {
return nil
}
if p.Database == nil || *p.Database == "" {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
detectedVersion, err := detectDatabaseVersion(ctx, conn)
if err != nil {
return err
}
p.Version = detectedVersion
return nil
}
// IsUserReadOnly checks if the database user has read-only privileges.
//
// This method performs a comprehensive security check by examining:
// - Role-level attributes (superuser, createrole, createdb)
// - Database-level privileges (CREATE, TEMP)
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
//
// A user is considered read-only only if they have ZERO write privileges
// across all three levels. This ensures the database user follows the
// principle of least privilege for backup operations.
func (p *PostgresqlDatabase) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (bool, error) {
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return false, fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return false, fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
// LEVEL 1: Check role-level attributes
var isSuperuser, canCreateRole, canCreateDB bool
err = conn.QueryRow(ctx, `
SELECT
rolsuper,
rolcreaterole,
rolcreatedb
FROM pg_roles
WHERE rolname = current_user
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
if err != nil {
return false, fmt.Errorf("failed to check role attributes: %w", err)
}
if isSuperuser || canCreateRole || canCreateDB {
return false, nil
}
// LEVEL 2: Check database-level privileges
var canCreate, canTemp bool
err = conn.QueryRow(ctx, `
SELECT
has_database_privilege(current_user, current_database(), 'CREATE') as can_create,
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
`).Scan(&canCreate, &canTemp)
if err != nil {
return false, fmt.Errorf("failed to check database privileges: %w", err)
}
if canCreate || canTemp {
return false, nil
}
// LEVEL 2.5: Check schema-level CREATE privileges
schemaRows, err := conn.Query(ctx, `
SELECT DISTINCT nspname
FROM pg_namespace n
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
`)
if err != nil {
return false, fmt.Errorf("failed to check schema privileges: %w", err)
}
defer schemaRows.Close()
// If user has CREATE privilege on any schema, they're not read-only
if schemaRows.Next() {
return false, nil
}
if err := schemaRows.Err(); err != nil {
return false, fmt.Errorf("error iterating schema privileges: %w", err)
}
// LEVEL 3: Check table-level write permissions
rows, err := conn.Query(ctx, `
SELECT DISTINCT privilege_type
FROM information_schema.role_table_grants
WHERE grantee = current_user
AND table_schema NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return false, fmt.Errorf("failed to check table privileges: %w", err)
}
defer rows.Close()
writePrivileges := map[string]bool{
"INSERT": true,
"UPDATE": true,
"DELETE": true,
"TRUNCATE": true,
"REFERENCES": true,
"TRIGGER": true,
}
for rows.Next() {
var privilege string
if err := rows.Scan(&privilege); err != nil {
return false, fmt.Errorf("failed to scan privilege: %w", err)
}
if writePrivileges[privilege] {
return false, nil
}
}
if err := rows.Err(); err != nil {
return false, fmt.Errorf("error iterating privileges: %w", err)
}
return true, nil
}
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
// 2. Grants CONNECT privilege on the database
// 3. Grants USAGE on all non-system schemas
// 4. Grants SELECT on all existing tables and sequences
// 5. Sets default privileges for future tables and sequences
//
// Security features:
// - Username format: "postgresus-{8-char-uuid}" for uniqueness
// - Password: Full UUID (36 characters) for strong entropy
// - Transaction safety: All operations rollback on any failure
// - Retry logic: Up to 3 attempts if username collision occurs
// - Pre-validation: Checks CREATEROLE privilege before starting transaction
func (p *PostgresqlDatabase) CreateReadOnlyUser(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, string, error) {
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
if err != nil {
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
}
connStr := buildConnectionStringForDB(p, *p.Database, password)
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return "", "", fmt.Errorf("failed to connect to database: %w", err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
logger.Error("Failed to close connection", "error", closeErr)
}
}()
// Pre-validate: Check if current user can create roles
var canCreateRole, isSuperuser bool
err = conn.QueryRow(ctx, `
SELECT rolcreaterole, rolsuper
FROM pg_roles
WHERE rolname = current_user
`).Scan(&canCreateRole, &isSuperuser)
if err != nil {
return "", "", fmt.Errorf("failed to check permissions: %w", err)
}
if !canCreateRole && !isSuperuser {
return "", "", errors.New("current database user lacks CREATEROLE privilege")
}
// Retry logic for username collision
maxRetries := 3
for attempt := range maxRetries {
// Generate base username for PostgreSQL user creation
baseUsername := fmt.Sprintf("postgresus-%s", uuid.New().String()[:8])
// For Supabase session pooler, the username format for connection is "username.projectid"
// but the actual PostgreSQL user must be created with just the base name.
// The pooler will strip the ".projectid" suffix when authenticating.
connectionUsername := baseUsername
if isSupabaseConnection(p.Host, p.Username) {
if supabaseProjectID := extractSupabaseProjectID(p.Username); supabaseProjectID != "" {
connectionUsername = fmt.Sprintf("%s.%s", baseUsername, supabaseProjectID)
}
}
newPassword := uuid.New().String()
tx, err := conn.Begin(ctx)
if err != nil {
return "", "", fmt.Errorf("failed to begin transaction: %w", err)
}
success := false
defer func() {
if !success {
if rollbackErr := tx.Rollback(ctx); rollbackErr != nil {
logger.Error("Failed to rollback transaction", "error", rollbackErr)
}
}
}()
// Step 1: Create PostgreSQL user with LOGIN privilege
// Note: We use baseUsername for the actual PostgreSQL user name if Supabase is used
_, err = tx.Exec(
ctx,
fmt.Sprintf(`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`, baseUsername, newPassword),
)
if err != nil {
if err.Error() != "" && attempt < maxRetries-1 {
continue
}
return "", "", fmt.Errorf("failed to create user: %w", err)
}
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
// public schema through the PUBLIC role. This is a one-time operation that affects
// the entire database, making it more secure by default.
// Note: This only affects the public schema; other schemas are unaffected.
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
if err != nil {
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
!strings.Contains(err.Error(), "permission denied") {
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
}
}
// Now revoke from the specific user as well (belt and suspenders)
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
if err != nil {
logger.Error(
"Failed to revoke CREATE on public schema from user",
"error",
err,
"username",
baseUsername,
)
}
// Step 2: Grant database connection privilege and revoke TEMP
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
)
if err != nil {
return "", "", fmt.Errorf("failed to grant connect privilege: %w", err)
}
// Revoke TEMP privilege from PUBLIC role (like CREATE on public schema, TEMP is granted to PUBLIC by default)
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE TEMP ON DATABASE "%s" FROM PUBLIC`, *p.Database))
if err != nil {
logger.Warn("Failed to revoke TEMP from PUBLIC", "error", err)
}
// Also revoke from the specific user (belt and suspenders)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`REVOKE TEMP ON DATABASE "%s" FROM "%s"`, *p.Database, baseUsername),
)
if err != nil {
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
// Step 3: Discover all user-created schemas
rows, err := tx.Query(ctx, `
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
`)
if err != nil {
return "", "", fmt.Errorf("failed to get schemas: %w", err)
}
var schemas []string
for rows.Next() {
var schema string
if err := rows.Scan(&schema); err != nil {
rows.Close()
return "", "", fmt.Errorf("failed to scan schema: %w", err)
}
schemas = append(schemas, schema)
}
rows.Close()
if err := rows.Err(); err != nil {
return "", "", fmt.Errorf("error iterating schemas: %w", err)
}
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
for _, schema := range schemas {
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`REVOKE CREATE ON SCHEMA "%s" FROM "%s"`, schema, baseUsername),
)
if err != nil {
logger.Warn(
"Failed to revoke CREATE on schema",
"error",
err,
"schema",
schema,
"username",
baseUsername,
)
}
// Grant only USAGE (not CREATE)
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT USAGE ON SCHEMA "%s" TO "%s"`, schema, baseUsername),
)
if err != nil {
return "", "", fmt.Errorf("failed to grant usage on schema %s: %w", schema, err)
}
}
// Step 5: Grant SELECT on ALL existing tables and sequences
grantSelectSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
_, err = tx.Exec(ctx, grantSelectSQL)
if err != nil {
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
}
// Step 6: Set default privileges for FUTURE tables and sequences
defaultPrivilegesSQL := fmt.Sprintf(`
DO $$
DECLARE
schema_rec RECORD;
BEGIN
FOR schema_rec IN
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
LOOP
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
END LOOP;
END $$;
`, baseUsername, baseUsername)
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
if err != nil {
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
}
// Step 7: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)
if err != nil {
return "", "", fmt.Errorf("failed to verify user creation: %w", err)
}
if err := tx.Commit(ctx); err != nil {
return "", "", fmt.Errorf("failed to commit transaction: %w", err)
}
success = true
// Return connectionUsername (with project ID suffix for Supabase) for the caller to use when connecting
logger.Info(
"Read-only user created successfully",
"username",
baseUsername,
"connectionUsername",
connectionUsername,
)
return connectionUsername, newPassword, nil
}
return "", "", errors.New("failed to generate unique username after 3 attempts")
}
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
func testSingleDatabaseConnection(
logger *slog.Logger,
ctx context.Context,
postgresDb *PostgresqlDatabase,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
// For single database backup, we need to connect to the specific database
if postgresDb.Database == nil || *postgresDb.Database == "" {
return errors.New("database name is required for single database backup (pg_dump)")
}
// Decrypt password if needed
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
// Build connection string for the specific database
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
// Test connection
conn, err := pgx.Connect(ctx, connStr)
@@ -112,10 +610,12 @@ func testSingleDatabaseConnection(
}
}()
// Check version after successful connection
if err := verifyDatabaseVersion(ctx, conn, postgresDb.Version); err != nil {
// Detect and set the database version automatically
detectedVersion, err := detectDatabaseVersion(ctx, conn)
if err != nil {
return err
}
postgresDb.Version = detectedVersion
// Test if we can perform basic operations (like pg_dump would need)
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
@@ -129,35 +629,31 @@ func testSingleDatabaseConnection(
return nil
}
// verifyDatabaseVersion checks if the actual database version matches the specified version
func verifyDatabaseVersion(
ctx context.Context,
conn *pgx.Conn,
expectedVersion tools.PostgresqlVersion,
) error {
// detectDatabaseVersion queries and returns the PostgreSQL major version
func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.PostgresqlVersion, error) {
var versionStr string
err := conn.QueryRow(ctx, "SELECT version()").Scan(&versionStr)
if err != nil {
return fmt.Errorf("failed to query database version: %w", err)
return "", fmt.Errorf("failed to query database version: %w", err)
}
// Parse version from string like "PostgreSQL 14.2 on x86_64-pc-linux-gnu..."
re := regexp.MustCompile(`PostgreSQL (\d+)\.`)
// or "PostgreSQL 16 maintained by Postgre BY..." (some builds omit minor version)
re := regexp.MustCompile(`PostgreSQL (\d+)`)
matches := re.FindStringSubmatch(versionStr)
if len(matches) < 2 {
return fmt.Errorf("could not parse version from: %s", versionStr)
return "", fmt.Errorf("could not parse version from: %s", versionStr)
}
actualVersion := tools.GetPostgresqlVersionEnum(matches[1])
if actualVersion != expectedVersion {
return fmt.Errorf(
"you specified wrong version. Real version is %s, but you specified %s",
actualVersion,
expectedVersion,
)
}
majorVersion := matches[1]
return nil
// Map to known PostgresqlVersion enum values
switch majorVersion {
case "12", "13", "14", "15", "16", "17", "18":
return tools.PostgresqlVersion(majorVersion), nil
default:
return "", fmt.Errorf("unsupported PostgreSQL version: %s", majorVersion)
}
}
// testBasicOperations tests basic operations that backup tools need
@@ -178,116 +674,42 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
}
// buildConnectionStringForDB builds connection string for specific database
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
sslMode := "disable"
if p.IsHttps {
sslMode = "require"
}
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
p.Host,
p.Port,
p.Username,
p.Password,
password,
dbName,
sslMode,
)
}
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
if len(extensions) == 0 {
return nil
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, error) {
if encryptor == nil {
return password, nil
}
if p.Database == nil || *p.Database == "" {
return errors.New("database name is required for installing extensions")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Build connection string for the specific database
connStr := buildConnectionStringForDB(p, *p.Database)
// Connect to database
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
fmt.Println("failed to close connection: %w", closeErr)
}
}()
// Check which extensions are already installed
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
if err != nil {
return fmt.Errorf("failed to check installed extensions: %w", err)
}
// Install missing extensions
for _, extension := range extensions {
if contains(installedExtensions, string(extension)) {
continue // Extension already installed
}
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
}
}
return nil
return encryptor.Decrypt(databaseID, password)
}
// getInstalledExtensions queries the database for currently installed extensions
func (p *PostgresqlDatabase) getInstalledExtensions(
ctx context.Context,
conn *pgx.Conn,
) ([]string, error) {
query := "SELECT extname FROM pg_extension"
rows, err := conn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
}
defer rows.Close()
var extensions []string
for rows.Next() {
var extname string
if err := rows.Scan(&extname); err != nil {
return nil, fmt.Errorf("failed to scan extension name: %w", err)
}
extensions = append(extensions, extname)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
}
return extensions, nil
func isSupabaseConnection(host, username string) bool {
return strings.Contains(strings.ToLower(host), "supabase") ||
strings.Contains(strings.ToLower(username), "supabase")
}
// installExtension installs a single PostgreSQL extension
func (p *PostgresqlDatabase) installExtension(
ctx context.Context,
conn *pgx.Conn,
extensionName string,
) error {
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
_, err := conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
func extractSupabaseProjectID(username string) string {
if idx := strings.Index(username, "."); idx != -1 {
return username[idx+1:]
}
return nil
}
// contains checks if a string slice contains a specific string
func contains(slice []string, item string) bool {
return slices.Contains(slice, item)
return ""
}

View File

@@ -0,0 +1,505 @@
package postgresql
import (
"context"
"fmt"
"log/slog"
"os"
"strconv"
"strings"
"testing"
"github.com/google/uuid"
"github.com/jmoiron/sqlx"
_ "github.com/lib/pq"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/config"
"postgresus-backend/internal/util/tools"
)
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.False(t, isReadOnly, "Admin user should not be read-only")
})
}
}
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS readonly_test CASCADE;
DROP TABLE IF EXISTS hack_table CASCADE;
DROP TABLE IF EXISTS future_table CASCADE;
CREATE TABLE readonly_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "postgresus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.True(t, isReadOnly, "Created user should be read-only")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
_, err = container.DB.Exec(`
CREATE TABLE future_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO future_table (data) VALUES ('future_data');
`)
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var data string
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "future_data", data)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS schema_a;
CREATE SCHEMA IF NOT EXISTS schema_b;
CREATE TABLE schema_a.table_a (id INT, data TEXT);
CREATE TABLE schema_b.table_b (id INT, data TEXT);
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, container.Database)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var dataA string
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_a", dataA)
var dataB string
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "data_b", dataB)
// Clean up: Drop user with CASCADE to handle default privilege dependencies
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
dashDbName := "test-db-with-dash"
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
}()
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, container.Username, container.Password, dashDbName)
dashDB, err := sqlx.Connect("postgres", dashDSN)
assert.NoError(t, err)
defer dashDB.Close()
_, err = dashDB.Exec(`
CREATE TABLE dash_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum("16"),
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &dashDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "postgresus-"))
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host, container.Port, username, password, dashDbName)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
env := config.GetEnv()
if env.TestSupabaseHost == "" {
t.Skip("Skipping Supabase test: missing environment variables")
}
portInt, err := strconv.Atoi(env.TestSupabasePort)
assert.NoError(t, err)
dsn := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
env.TestSupabaseUsername,
env.TestSupabasePassword,
env.TestSupabaseDatabase,
)
adminDB, err := sqlx.Connect("postgres", dsn)
assert.NoError(t, err)
defer adminDB.Close()
tableName := fmt.Sprintf(
"readonly_test_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = adminDB.Exec(fmt.Sprintf(`
DROP TABLE IF EXISTS public.%s CASCADE;
CREATE TABLE public.%s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
`, tableName, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
}()
pgModel := &PostgresqlDatabase{
Host: env.TestSupabaseHost,
Port: portInt,
Username: env.TestSupabaseUsername,
Password: env.TestSupabasePassword,
Database: &env.TestSupabaseDatabase,
IsHttps: true,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, connectionUsername)
assert.NotEmpty(t, newPassword)
assert.True(t, strings.HasPrefix(connectionUsername, "postgresus-"))
baseUsername := connectionUsername
if idx := strings.Index(connectionUsername, "."); idx != -1 {
baseUsername = connectionUsername[:idx]
}
defer func() {
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
}()
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
env.TestSupabaseHost,
portInt,
connectionUsername,
newPassword,
env.TestSupabaseDatabase,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
}
type PostgresContainer struct {
Host string
Port int
Username string
Password string
Database string
DB *sqlx.DB
}
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, portInt, username, password, dbName)
db, err := sqlx.Connect("postgres", dsn)
assert.NoError(t, err)
var versionStr string
err = db.Get(&versionStr, "SELECT version()")
assert.NoError(t, err)
return &PostgresContainer{
Host: host,
Port: portInt,
Username: username,
Password: password,
Database: dbName,
DB: db,
}
}
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
var versionStr string
err := container.DB.Get(&versionStr, "SELECT version()")
if err != nil {
return nil
}
version := extractPostgresVersion(versionStr)
return &PostgresqlDatabase{
Version: version,
Host: container.Host,
Port: container.Port,
Username: container.Username,
Password: container.Password,
Database: &container.Database,
IsHttps: false,
}
}
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
if strings.Contains(versionStr, "PostgreSQL 12") {
return tools.GetPostgresqlVersionEnum("12")
} else if strings.Contains(versionStr, "PostgreSQL 13") {
return tools.GetPostgresqlVersionEnum("13")
} else if strings.Contains(versionStr, "PostgreSQL 14") {
return tools.GetPostgresqlVersionEnum("14")
} else if strings.Contains(versionStr, "PostgreSQL 15") {
return tools.GetPostgresqlVersionEnum("15")
} else if strings.Contains(versionStr, "PostgreSQL 16") {
return tools.GetPostgresqlVersionEnum("16")
} else if strings.Contains(versionStr, "PostgreSQL 17") {
return tools.GetPostgresqlVersionEnum("17")
}
return tools.GetPostgresqlVersionEnum("16")
}

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -19,6 +20,7 @@ var databaseService = &DatabaseService{
[]DatabaseCopyListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var databaseController = &DatabaseController{

View File

@@ -0,0 +1,10 @@
package databases
type CreateReadOnlyUserResponse struct {
Username string `json:"username"`
Password string `json:"password"`
}
type IsReadOnlyResponse struct {
IsReadOnly bool `json:"isReadOnly"`
}

View File

@@ -2,6 +2,7 @@ package databases
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -11,7 +12,11 @@ type DatabaseValidator interface {
}
type DatabaseConnector interface {
TestConnection(logger *slog.Logger) error
TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error
HideSensitiveData()
}

View File

@@ -5,6 +5,7 @@ import (
"log/slog"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -56,17 +57,38 @@ func (d *Database) ValidateUpdate(old, new Database) error {
return nil
}
func (d *Database) TestConnection(logger *slog.Logger) error {
return d.getSpecificDatabase().TestConnection(logger)
func (d *Database) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
if d.Postgresql != nil {
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
}
return nil
}
func (d *Database) PopulateVersionIfEmpty(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
if d.Postgresql != nil {
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
}
return nil
}
func (d *Database) Update(incoming *Database) {
d.Name = incoming.Name
d.Type = incoming.Type
d.Notifiers = incoming.Notifiers
switch d.Type {
case DatabaseTypePostgres:

View File

@@ -1,6 +1,7 @@
package databases
import (
"context"
"errors"
"fmt"
"log/slog"
@@ -11,6 +12,7 @@ import (
"postgresus-backend/internal/features/notifiers"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -26,6 +28,7 @@ type DatabaseService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *DatabaseService) AddDbCreationListener(
@@ -65,6 +68,14 @@ func (s *DatabaseService) CreateDatabase(
return nil, err
}
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
database, err = s.dbRepository.Save(database)
if err != nil {
return nil, err
@@ -118,6 +129,14 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
}
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
_, err = s.dbRepository.Save(existingDatabase)
if err != nil {
return err
@@ -219,7 +238,6 @@ func (s *DatabaseService) GetDatabasesByWorkspace(
func (s *DatabaseService) IsNotifierUsing(
user *users_models.User,
workspaceID uuid.UUID,
notifierID uuid.UUID,
) (bool, error) {
_, err := s.notifierService.GetNotifier(user, notifierID)
@@ -251,7 +269,7 @@ func (s *DatabaseService) TestDatabaseConnection(
return errors.New("insufficient permissions to test connection for this database")
}
err = database.TestConnection(s.logger)
err = database.TestConnection(s.logger, s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
database.LastBackupErrorMessage = &lastSaveError
@@ -295,7 +313,7 @@ func (s *DatabaseService) TestDatabaseConnectionDirect(
usingDatabase = database
}
return usingDatabase.TestConnection(s.logger)
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) GetDatabaseByID(
@@ -447,3 +465,148 @@ func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error
return nil
}
func (s *DatabaseService) IsUserReadOnly(
user *users_models.User,
database *Database,
) (bool, error) {
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return false, err
}
if existingDatabase.WorkspaceID == nil {
return false, errors.New("cannot check user for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*existingDatabase.WorkspaceID,
user,
)
if err != nil {
return false, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this database")
}
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
return false, errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return false, err
}
usingDatabase = existingDatabase
} else {
if database.WorkspaceID != nil {
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return false, err
}
if !canAccess {
return false, errors.New("insufficient permissions to access this workspace")
}
}
usingDatabase = database
}
if usingDatabase.Type != DatabaseTypePostgres {
return false, errors.New("read-only check only supported for PostgreSQL databases")
}
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return usingDatabase.Postgresql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
}
func (s *DatabaseService) CreateReadOnlyUser(
user *users_models.User,
database *Database,
) (string, string, error) {
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return "", "", err
}
if existingDatabase.WorkspaceID == nil {
return "", "", errors.New("cannot create user for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return "", "", err
}
if !canManage {
return "", "", errors.New("insufficient permissions to manage this database")
}
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
return "", "", errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return "", "", err
}
usingDatabase = existingDatabase
} else {
if database.WorkspaceID != nil {
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return "", "", err
}
if !canManage {
return "", "", errors.New("insufficient permissions to manage this workspace")
}
}
usingDatabase = database
}
if usingDatabase.Type != DatabaseTypePostgres {
return "", "", errors.New("read-only user creation only supported for PostgreSQL")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
username, password, err := usingDatabase.Postgresql.CreateReadOnlyUser(
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
)
if err != nil {
return "", "", err
}
if usingDatabase.WorkspaceID != nil {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Read-only user created for database: %s (username: %s)",
usingDatabase.Name,
username,
),
&user.ID,
usingDatabase.WorkspaceID,
)
}
return username, password, nil
}

View File

@@ -0,0 +1,9 @@
package secrets
var secretKeyService = &SecretKeyService{
nil,
}
func GetSecretKeyService() *SecretKeyService {
return secretKeyService
}

View File

@@ -0,0 +1 @@
package secrets

View File

@@ -0,0 +1,73 @@
package secrets
import (
"errors"
"fmt"
"os"
"postgresus-backend/internal/config"
user_models "postgresus-backend/internal/features/users/models"
"postgresus-backend/internal/storage"
"github.com/google/uuid"
"gorm.io/gorm"
)
type SecretKeyService struct {
cachedKey *string
}
func (s *SecretKeyService) MigrateKeyFromDbToFileIfExist() error {
var secretKey user_models.SecretKey
err := storage.GetDb().First(&secretKey).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("failed to check for secret key in database: %w", err)
}
if secretKey.Secret == "" {
return nil
}
secretKeyPath := config.GetEnv().SecretKeyPath
if err := os.WriteFile(secretKeyPath, []byte(secretKey.Secret), 0600); err != nil {
return fmt.Errorf("failed to write secret key to file: %w", err)
}
if err := storage.GetDb().Exec("DELETE FROM secret_keys").Error; err != nil {
return fmt.Errorf("failed to delete secret key from database: %w", err)
}
return nil
}
func (s *SecretKeyService) GetSecretKey() (string, error) {
if s.cachedKey != nil {
return *s.cachedKey, nil
}
secretKeyPath := config.GetEnv().SecretKeyPath
data, err := os.ReadFile(secretKeyPath)
if err != nil {
if os.IsNotExist(err) {
newKey := s.generateNewSecretKey()
if err := os.WriteFile(secretKeyPath, []byte(newKey), 0600); err != nil {
return "", fmt.Errorf("failed to write new secret key: %w", err)
}
s.cachedKey = &newKey
return newKey, nil
}
return "", fmt.Errorf("failed to read secret key file: %w", err)
}
key := string(data)
s.cachedKey = &key
return key, nil
}
func (s *SecretKeyService) generateNewSecretKey() string {
return uuid.New().String() + uuid.New().String()
}

View File

@@ -453,70 +453,6 @@ func Test_CrossWorkspaceSecurity_CannotAccessNotifierFromAnotherWorkspace(t *tes
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
@@ -553,7 +489,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-bot-token-12345", notifier.TelegramNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "original-bot-token-12345", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.TelegramNotifier.BotToken)
@@ -592,7 +534,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-password-secret", notifier.EmailNotifier.SMTPPassword)
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "original-password-secret", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.EmailNotifier.SMTPPassword)
@@ -625,7 +573,13 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "xoxb-original-slack-token", notifier.SlackNotifier.BotToken)
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "xoxb-original-slack-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.SlackNotifier.BotToken)
@@ -656,11 +610,17 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(
assert.True(
t,
"https://discord.com/api/webhooks/123/original-token",
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/original-token", decrypted)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.DiscordNotifier.ChannelWebhookURL)
@@ -691,10 +651,16 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted in DB",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(
t,
"https://outlook.office.com/webhook/original-token",
notifier.TeamsNotifier.WebhookURL,
decrypted,
)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
@@ -813,3 +779,263 @@ func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_CreateNotifier_AllSensitiveFieldsEncryptedInDB(t *testing.T) {
testCases := []struct {
name string
createNotifier func(workspaceID uuid.UUID) *Notifier
verifySensitiveEncryption func(t *testing.T, notifier *Notifier)
}{
{
name: "Telegram Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram",
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: "plain-telegram-token-123",
TargetChatID: "123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TelegramNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TelegramNotifier.BotToken)
assert.Equal(t, "plain-telegram-token-123", decrypted)
},
},
{
name: "Email Notifier - SMTPPassword encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Email",
NotifierType: NotifierTypeEmail,
EmailNotifier: &email_notifier.EmailNotifier{
TargetEmail: "test@example.com",
SMTPHost: "smtp.example.com",
SMTPPort: 587,
SMTPUser: "user@example.com",
SMTPPassword: "plain-smtp-password-456",
From: "noreply@example.com",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.EmailNotifier.SMTPPassword),
"SMTPPassword should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.EmailNotifier.SMTPPassword)
assert.Equal(t, "plain-smtp-password-456", decrypted)
},
},
{
name: "Slack Notifier - BotToken encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Slack",
NotifierType: NotifierTypeSlack,
SlackNotifier: &slack_notifier.SlackNotifier{
BotToken: "plain-slack-token-789",
TargetChatID: "C0123456789",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.SlackNotifier.BotToken),
"BotToken should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.SlackNotifier.BotToken)
assert.Equal(t, "plain-slack-token-789", decrypted)
},
},
{
name: "Discord Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Discord",
NotifierType: NotifierTypeDiscord,
DiscordNotifier: &discord_notifier.DiscordNotifier{
ChannelWebhookURL: "https://discord.com/api/webhooks/123/abc",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.DiscordNotifier.ChannelWebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(
t,
notifier.ID,
notifier.DiscordNotifier.ChannelWebhookURL,
)
assert.Equal(t, "https://discord.com/api/webhooks/123/abc", decrypted)
},
},
{
name: "Teams Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Teams",
NotifierType: NotifierTypeTeams,
TeamsNotifier: &teams_notifier.TeamsNotifier{
WebhookURL: "https://outlook.office.com/webhook/test123",
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.TeamsNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.TeamsNotifier.WebhookURL)
assert.Equal(t, "https://outlook.office.com/webhook/test123", decrypted)
},
},
{
name: "Webhook Notifier - WebhookURL encrypted",
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Webhook",
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test456",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
},
verifySensitiveEncryption: func(t *testing.T, notifier *Notifier) {
assert.True(
t,
isEncrypted(notifier.WebhookNotifier.WebhookURL),
"WebhookURL should be encrypted",
)
decrypted := decryptField(t, notifier.ID, notifier.WebhookNotifier.WebhookURL)
assert.Equal(t, "https://webhook.example.com/test456", decrypted)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Create notifier via API (plaintext credentials)
var createdNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
tc.createNotifier(workspace.ID),
http.StatusOK,
&createdNotifier,
)
// Read from DB directly (bypass service layer)
repository := &NotifierRepository{}
notifierFromDB, err := repository.FindByID(createdNotifier.ID)
assert.NoError(t, err)
// Verify encryption
tc.verifySensitiveEncryption(t, notifierFromDB)
// Cleanup
deleteNotifier(t, router, createdNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func isEncrypted(value string) bool {
return len(value) > 4 && value[:4] == "enc:"
}
func decryptField(t *testing.T, notifierID uuid.UUID, encryptedValue string) string {
encryptor := GetNotifierService().fieldEncryptor
decrypted, err := encryptor.Decrypt(notifierID, encryptedValue)
assert.NoError(t, err)
return decrypted
}

View File

@@ -3,6 +3,7 @@ package notifiers
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -12,6 +13,7 @@ var notifierService = &NotifierService{
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var notifierController = &NotifierController{
notifierService,
@@ -26,6 +28,9 @@ func GetNotifierService() *NotifierService {
return notifierService
}
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
}

View File

@@ -1,11 +1,21 @@
package notifiers
import "log/slog"
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
)
type NotificationSender interface {
Send(logger *slog.Logger, heading string, message string) error
Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -9,6 +9,7 @@ import (
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
return "notifiers"
}
func (n *Notifier) Validate() error {
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.Name == "" {
return errors.New("name is required")
}
return n.getSpecificNotifier().Validate()
return n.getSpecificNotifier().Validate(encryptor)
}
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
err := n.getSpecificNotifier().Send(logger, heading, message)
func (n *Notifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
if err != nil {
lastSendError := err.Error()
@@ -58,6 +64,10 @@ func (n *Notifier) HideSensitiveData() {
n.getSpecificNotifier().HideSensitiveData()
}
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
}
func (n *Notifier) Update(incoming *Notifier) {
n.Name = incoming.Name
n.NotifierType = incoming.NotifierType

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
return "discord_notifiers"
}
func (d *DiscordNotifier) Validate() error {
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
return nil
}
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (d *DiscordNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
@@ -81,3 +92,14 @@ func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
d.ChannelWebhookURL = incoming.ChannelWebhookURL
}
}
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL != "" {
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
d.ChannelWebhookURL = encrypted
}
return nil
}

View File

@@ -0,0 +1,28 @@
package email_notifier
import (
"errors"
"net/smtp"
)
type loginAuth struct {
username, password string
}
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:":
return []byte(a.username), nil
case "Password:":
return []byte(a.password), nil
default:
return nil, errors.New("unknown LOGIN challenge: " + string(fromServer))
}
}
return nil, nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net"
"net/smtp"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -27,13 +28,14 @@ type EmailNotifier struct {
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
}
func (e *EmailNotifier) TableName() string {
return "email_notifiers"
}
func (e *EmailNotifier) Validate() error {
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if e.TargetEmail == "" {
return errors.New("target email is required")
}
@@ -54,159 +56,36 @@ func (e *EmailNotifier) Validate() error {
return nil
}
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
// Compose email
from := e.SMTPUser
func (e *EmailNotifier) Send(
encryptor encryption.FieldEncryptor,
_ *slog.Logger,
heading string,
message string,
) error {
var smtpPassword string
if e.SMTPPassword != "" {
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
}
smtpPassword = decrypted
}
from := e.From
if from == "" {
from = "noreply@" + e.SMTPHost
from = e.SMTPUser
if from == "" {
from = "noreply@" + e.SMTPHost
}
}
to := []string{e.TargetEmail}
emailContent := e.buildEmailContent(heading, message, from)
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
// Format the email content
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
body := message
fromHeader := fmt.Sprintf("From: %s\r\n", from)
// Combine all parts of the email
emailContent := []byte(fromHeader + subject + mime + body)
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
timeout := DefaultTimeout
// Determine if authentication is required
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
// Handle different port scenarios
if e.SMTPPort == ImplicitTLSPort {
// Implicit TLS (port 465)
// Set up TLS config
tlsConfig := &tls.Config{
ServerName: e.SMTPHost,
}
// Dial with timeout
dialer := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return fmt.Errorf("failed to connect to SMTP server: %w", err)
}
defer func() {
_ = conn.Close()
}()
// Create SMTP client
client, err := smtp.NewClient(conn, e.SMTPHost)
if err != nil {
return fmt.Errorf("failed to create SMTP client: %w", err)
}
defer func() {
_ = client.Quit()
}()
// Set up authentication only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
}
// Set sender and recipients
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
for _, recipient := range to {
if err := client.Rcpt(recipient); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
}
// Send the email body
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
_, err = writer.Write(emailContent)
if err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
err = writer.Close()
if err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
} else {
// STARTTLS (port 587) or other ports
// Create a custom dialer with timeout
dialer := &net.Dialer{Timeout: timeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return fmt.Errorf("failed to connect to SMTP server: %w", err)
}
// Create client from connection
client, err := smtp.NewClient(conn, e.SMTPHost)
if err != nil {
return fmt.Errorf("failed to create SMTP client: %w", err)
}
defer func() {
_ = client.Quit()
}()
// Send email using the client
if err := client.Hello(DefaultHelloName); err != nil {
return fmt.Errorf("SMTP hello failed: %w", err)
}
// Start TLS if available
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: e.SMTPHost}); err != nil {
return fmt.Errorf("STARTTLS failed: %w", err)
}
}
// Authenticate only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
}
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
for _, recipient := range to {
if err := client.Rcpt(recipient); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
}
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
_, err = writer.Write(emailContent)
if err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
err = writer.Close()
if err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return client.Quit()
return e.sendImplicitTLS(emailContent, from, smtpPassword, isAuthRequired)
}
return e.sendStartTLS(emailContent, from, smtpPassword, isAuthRequired)
}
func (e *EmailNotifier) HideSensitiveData() {
@@ -218,8 +97,183 @@ func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.SMTPHost = incoming.SMTPHost
e.SMTPPort = incoming.SMTPPort
e.SMTPUser = incoming.SMTPUser
e.From = incoming.From
if incoming.SMTPPassword != "" {
e.SMTPPassword = incoming.SMTPPassword
}
}
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if e.SMTPPassword != "" {
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
}
e.SMTPPassword = encrypted
}
return nil
}
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
fromHeader := fmt.Sprintf("From: %s\r\n", from)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + mime + message)
}
func (e *EmailNotifier) sendImplicitTLS(
emailContent []byte,
from string,
password string,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return e.createImplicitTLSClient()
}
client, cleanup, err := e.authenticateWithRetry(createClient, password, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return e.sendEmail(client, from, emailContent)
}
func (e *EmailNotifier) sendStartTLS(
emailContent []byte,
from string,
password string,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return e.createStartTLSClient()
}
client, cleanup, err := e.authenticateWithRetry(createClient, password, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return e.sendEmail(client, from, emailContent)
}
func (e *EmailNotifier) createImplicitTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
tlsConfig := &tls.Config{ServerName: e.SMTPHost}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, e.SMTPHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
return client, func() { _ = client.Quit() }, nil
}
func (e *EmailNotifier) createStartTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, e.SMTPHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
}
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: e.SMTPHost}); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
}
}
return client, func() { _ = client.Quit() }, nil
}
func (e *EmailNotifier) authenticateWithRetry(
createClient func() (*smtp.Client, func(), error),
password string,
isAuthRequired bool,
) (*smtp.Client, func(), error) {
client, cleanup, err := createClient()
if err != nil {
return nil, nil, err
}
if !isAuthRequired {
return client, cleanup, nil
}
// Try PLAIN auth first
plainAuth := smtp.PlainAuth("", e.SMTPUser, password, e.SMTPHost)
if err := client.Auth(plainAuth); err == nil {
return client, cleanup, nil
}
// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
cleanup()
client, cleanup, err = createClient()
if err != nil {
return nil, nil, err
}
loginAuth := &loginAuth{username: e.SMTPUser, password: password}
if err := client.Auth(loginAuth); err != nil {
cleanup()
return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
}
return client, cleanup, nil
}
func (e *EmailNotifier) sendEmail(client *smtp.Client, from string, content []byte) error {
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
if err := client.Rcpt(e.TargetEmail); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
if _, err = writer.Write(content); err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
if err = writer.Close(); err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
"time"
@@ -23,7 +24,7 @@ type SlackNotifier struct {
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
func (s *SlackNotifier) Validate() error {
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if s.BotToken == "" {
return errors.New("bot token is required")
}
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
return nil
}
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
func (s *SlackNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
full := fmt.Sprintf("*%s*", heading)
if message != "" {
@@ -60,6 +70,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
maxAttempts = 5
defaultBackoff = 2 * time.Second // when Retry-After header missing
backoffMultiplier = 1.5 // use exponential growth
requestTimeout = 30 * time.Second
)
var (
@@ -67,6 +78,10 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
attempts = 0
)
client := &http.Client{
Timeout: requestTimeout,
}
for {
attempts++
@@ -80,9 +95,9 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Authorization", "Bearer "+s.BotToken)
req.Header.Set("Authorization", "Bearer "+botToken)
resp, err := http.DefaultClient.Do(req)
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("send slack message: %w", err)
}
@@ -144,3 +159,14 @@ func (s *SlackNotifier) Update(incoming *SlackNotifier) {
s.BotToken = incoming.BotToken
}
}
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if s.BotToken != "" {
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
s.BotToken = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
return "teams_notifiers"
}
func (n *TeamsNotifier) Validate() error {
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL == "" {
return errors.New("webhook_url is required")
}
u, err := url.Parse(n.WebhookURL)
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
u, err := url.Parse(webhookURL)
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
return errors.New("invalid webhook_url")
}
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
}
type cardAttachment struct {
ContentType string `json:"contentType"`
Content interface{} `json:"content"`
ContentType string `json:"contentType"`
Content any `json:"content"`
}
type payload struct {
@@ -43,11 +50,20 @@ type payload struct {
Attachments []cardAttachment `json:"attachments,omitempty"`
}
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
if err := n.Validate(); err != nil {
func (n *TeamsNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
if err := n.Validate(encryptor); err != nil {
return err
}
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
card := map[string]any{
"type": "AdaptiveCard",
"version": "1.4",
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
}
body, _ := json.Marshal(p)
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return err
}
@@ -104,3 +120,14 @@ func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
n.WebhookURL = incoming.WebhookURL
}
}
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
n.WebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
return "telegram_notifiers"
}
func (t *TelegramNotifier) Validate() error {
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.BotToken == "" {
return errors.New("bot token is required")
}
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
return nil
}
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *TelegramNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
}
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
data := url.Values{}
data.Set("chat_id", t.TargetChatID)
@@ -93,3 +104,14 @@ func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
t.BotToken = incoming.BotToken
}
}
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.BotToken != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
t.BotToken = encrypted
}
return nil
}

View File

@@ -9,21 +9,59 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"strings"
"github.com/google/uuid"
"gorm.io/gorm"
)
type WebhookHeader struct {
Key string `json:"key"`
Value string `json:"value"`
}
type WebhookNotifier struct {
NotifierID uuid.UUID `json:"notifierId" gorm:"primaryKey;column:notifier_id"`
WebhookURL string `json:"webhookUrl" gorm:"not null;column:webhook_url"`
WebhookMethod WebhookMethod `json:"webhookMethod" gorm:"not null;column:webhook_method"`
BodyTemplate *string `json:"bodyTemplate" gorm:"column:body_template;type:text"`
HeadersJSON string `json:"-" gorm:"column:headers;type:text"`
Headers []WebhookHeader `json:"headers" gorm:"-"`
}
func (t *WebhookNotifier) TableName() string {
return "webhook_notifiers"
}
func (t *WebhookNotifier) Validate() error {
func (t *WebhookNotifier) BeforeSave(_ *gorm.DB) error {
if len(t.Headers) > 0 {
data, err := json.Marshal(t.Headers)
if err != nil {
return err
}
t.HeadersJSON = string(data)
} else {
t.HeadersJSON = "[]"
}
return nil
}
func (t *WebhookNotifier) AfterFind(_ *gorm.DB) error {
if t.HeadersJSON != "" {
if err := json.Unmarshal([]byte(t.HeadersJSON), &t.Headers); err != nil {
return err
}
}
return nil
}
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -35,69 +73,22 @@ func (t *WebhookNotifier) Validate() error {
return nil
}
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *WebhookNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
switch t.WebhookMethod {
case WebhookMethodGET:
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
t.WebhookURL,
url.QueryEscape(heading),
url.QueryEscape(message),
)
resp, err := http.Get(reqURL)
if err != nil {
return fmt.Errorf("failed to send GET webhook: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
logger.Error("failed to close response body", "error", cerr)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf(
"webhook GET returned status: %s, body: %s",
resp.Status,
string(body),
)
}
return nil
return t.sendGET(webhookURL, heading, message, logger)
case WebhookMethodPOST:
payload := map[string]string{
"heading": heading,
"message": message,
}
body, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to marshal webhook payload: %w", err)
}
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to send POST webhook: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
logger.Error("failed to close response body", "error", cerr)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf(
"webhook POST returned status: %s, body: %s",
resp.Status,
string(body),
)
}
return nil
return t.sendPOST(webhookURL, heading, message, logger)
default:
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
@@ -109,4 +100,144 @@ func (t *WebhookNotifier) HideSensitiveData() {
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
t.WebhookURL = incoming.WebhookURL
t.WebhookMethod = incoming.WebhookMethod
t.BodyTemplate = incoming.BodyTemplate
t.Headers = incoming.Headers
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
t.WebhookURL = encrypted
}
return nil
}
func (t *WebhookNotifier) sendGET(webhookURL, heading, message string, logger *slog.Logger) error {
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
webhookURL,
url.QueryEscape(heading),
url.QueryEscape(message),
)
req, err := http.NewRequest(http.MethodGet, reqURL, nil)
if err != nil {
return fmt.Errorf("failed to create GET request: %w", err)
}
t.applyHeaders(req)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to send GET webhook: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
logger.Error("failed to close response body", "error", cerr)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
body, _ := io.ReadAll(resp.Body)
return fmt.Errorf(
"webhook GET returned status: %s, body: %s",
resp.Status,
string(body),
)
}
return nil
}
func (t *WebhookNotifier) sendPOST(webhookURL, heading, message string, logger *slog.Logger) error {
body := t.buildRequestBody(heading, message)
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to create POST request: %w", err)
}
hasContentType := false
for _, h := range t.Headers {
if strings.EqualFold(h.Key, "Content-Type") {
hasContentType = true
break
}
}
if !hasContentType {
req.Header.Set("Content-Type", "application/json")
}
t.applyHeaders(req)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to send POST webhook: %w", err)
}
defer func() {
if cerr := resp.Body.Close(); cerr != nil {
logger.Error("failed to close response body", "error", cerr)
}
}()
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
respBody, _ := io.ReadAll(resp.Body)
return fmt.Errorf(
"webhook POST returned status: %s, body: %s",
resp.Status,
string(respBody),
)
}
return nil
}
func (t *WebhookNotifier) buildRequestBody(heading, message string) []byte {
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
result := *t.BodyTemplate
result = strings.ReplaceAll(result, "{{heading}}", escapeJSONString(heading))
result = strings.ReplaceAll(result, "{{message}}", escapeJSONString(message))
return []byte(result)
}
payload := map[string]string{
"heading": heading,
"message": message,
}
body, _ := json.Marshal(payload)
return body
}
func (t *WebhookNotifier) applyHeaders(req *http.Request) {
for _, h := range t.Headers {
if h.Key != "" {
req.Header.Set(h.Key, h.Value)
}
}
}
func escapeJSONString(s string) string {
b, err := json.Marshal(s)
if err != nil || len(b) < 2 {
escaped := strings.ReplaceAll(s, `\`, `\\`)
escaped = strings.ReplaceAll(escaped, `"`, `\"`)
escaped = strings.ReplaceAll(escaped, "\n", `\n`)
escaped = strings.ReplaceAll(escaped, "\r", `\r`)
escaped = strings.ReplaceAll(escaped, "\t", `\t`)
return escaped
}
return string(b[1 : len(b)-1])
}

View File

@@ -165,7 +165,6 @@ func (r *NotifierRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Notifi
func (r *NotifierRepository) Delete(notifier *Notifier) error {
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
switch notifier.NotifierType {
case NotifierTypeTelegram:
if notifier.TelegramNotifier != nil {

View File

@@ -8,6 +8,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -17,6 +18,7 @@ type NotifierService struct {
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *NotifierService) SaveNotifier(
@@ -46,7 +48,11 @@ func (s *NotifierService) SaveNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -63,7 +69,11 @@ func (s *NotifierService) SaveNotifier(
} else {
notifier.WorkspaceID = workspaceID
if err := notifier.Validate(); err != nil {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := notifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -175,7 +185,7 @@ func (s *NotifierService) SendTestNotification(
return errors.New("insufficient permissions to test notifier in this workspace")
}
err = notifier.Send(s.logger, "Test message", "This is a test message")
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
if err != nil {
return err
}
@@ -205,16 +215,24 @@ func (s *NotifierService) SendTestNotificationToNotifier(
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = existingNotifier
} else {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = notifier
}
return usingNotifier.Send(s.logger, "Test message", "This is a test message")
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
}
func (s *NotifierService) SendNotification(
@@ -233,7 +251,7 @@ func (s *NotifierService) SendNotification(
return
}
err = notifiedFromDb.Send(s.logger, title, message)
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
if err != nil {
errMsg := err.Error()
notifiedFromDb.LastSendError = &errMsg

View File

@@ -1,6 +1,7 @@
package restores
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -29,6 +30,7 @@ import (
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
util_encryption "postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
@@ -169,6 +171,36 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
IsExcludeExtensions: true,
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
}
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -309,6 +341,7 @@ func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
@@ -323,9 +356,7 @@ func createTestBackup(
backup := &backups.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
Database: database,
StorageID: storages[0].ID,
Storage: storages[0],
Status: backups.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
@@ -340,7 +371,7 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
if err := storages[0].SaveFile(context.Background(), fieldEncryptor, logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}

View File

@@ -8,6 +8,7 @@ import (
"postgresus-backend/internal/features/restores/usecases"
"postgresus-backend/internal/features/storages"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -22,6 +23,7 @@ var restoreService = &RestoreService{
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var restoreController = &RestoreController{
restoreService,

View File

@@ -2,7 +2,6 @@ package models
import (
"postgresus-backend/internal/features/backups/backups"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/restores/enums"
"time"
@@ -16,8 +15,6 @@ type Restore struct {
BackupID uuid.UUID `json:"backupId" gorm:"column:backup_id;type:uuid;not null"`
Backup *backups.Backup
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:RestoreID"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
RestoreDurationMs int64 `json:"restoreDurationMs" gorm:"column:restore_duration_ms;default:0"`

View File

@@ -32,7 +32,6 @@ func (r *RestoreRepository) FindByBackupID(backupID uuid.UUID) ([]*models.Restor
if err := storage.
GetDb().
Preload("Backup").
Preload("Postgresql").
Where("backup_id = ?", backupID).
Order("created_at DESC").
Find(&restores).Error; err != nil {
@@ -48,7 +47,6 @@ func (r *RestoreRepository) FindByID(id uuid.UUID) (*models.Restore, error) {
if err := storage.
GetDb().
Preload("Backup").
Preload("Postgresql").
Where("id = ?", id).
First(&restore).Error; err != nil {
return nil, err
@@ -62,10 +60,7 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.
if err := storage.
GetDb().
Preload("Backup.Storage").
Preload("Backup.Database").
Preload("Backup").
Preload("Postgresql").
Where("status = ?", status).
Order("created_at DESC").
Find(&restores).Error; err != nil {

View File

@@ -14,6 +14,7 @@ import (
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"time"
@@ -30,6 +31,7 @@ type RestoreService struct {
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
@@ -62,12 +64,17 @@ func (s *RestoreService) GetRestores(
return nil, err
}
if backup.Database.WorkspaceID == nil {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot get restores for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*backup.Database.WorkspaceID,
*database.WorkspaceID,
user,
)
if err != nil {
@@ -90,12 +97,17 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if backup.Database.WorkspaceID == nil {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot restore backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*backup.Database.WorkspaceID,
*database.WorkspaceID,
user,
)
if err != nil {
@@ -110,12 +122,6 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
fmt.Printf(
"restore from %s to %s\n",
backupDatabase.Postgresql.Version,
requestDTO.PostgresqlDatabase.Version,
)
if tools.IsBackupDbVersionHigherThanRestoreDbVersion(
backupDatabase.Postgresql.Version,
requestDTO.PostgresqlDatabase.Version,
@@ -135,10 +141,10 @@ func (s *RestoreService) RestoreBackupWithAuth(
fmt.Sprintf(
"Database restored from backup %s for database: %s",
backupID.String(),
backup.Database.Name,
database.Name,
),
&user.ID,
backup.Database.WorkspaceID,
database.WorkspaceID,
)
return nil
@@ -152,7 +158,12 @@ func (s *RestoreService) RestoreBackup(
return errors.New("backup is not completed")
}
if backup.Database.Type == databases.DatabaseTypePostgres {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.Type == databases.DatabaseTypePostgres {
if requestDTO.PostgresqlDatabase == nil {
return errors.New("postgresql database is required")
}
@@ -176,15 +187,9 @@ func (s *RestoreService) RestoreBackup(
return err
}
// Set the RestoreID on the PostgreSQL database and save it
if requestDTO.PostgresqlDatabase != nil {
requestDTO.PostgresqlDatabase.RestoreID = &restore.ID
restore.Postgresql = requestDTO.PostgresqlDatabase
// Save the restore again to include the postgresql database
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
// Save the restore again to include the postgresql database
if err := s.restoreRepository.Save(&restore); err != nil {
return err
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
@@ -193,7 +198,7 @@ func (s *RestoreService) RestoreBackup(
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
backup.Database.ID,
database.ID,
)
if err != nil {
return err
@@ -201,11 +206,27 @@ func (s *RestoreService) RestoreBackup(
start := time.Now().UTC()
restoringToDB := &databases.Database{
Postgresql: requestDTO.PostgresqlDatabase,
}
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to auto-detect database version: %w", err)
}
isExcludeExtensions := false
if requestDTO.PostgresqlDatabase != nil {
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
}
err = s.restoreBackupUsecase.Execute(
backupConfig,
restore,
database,
restoringToDB,
backup,
storage,
isExcludeExtensions,
)
if err != nil {
errMsg := err.Error()

View File

@@ -1,11 +1,13 @@
package usecases_postgresql
import (
"postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/util/logger"
)
var restorePostgresqlBackupUsecase = &RestorePostgresqlBackupUsecase{
logger.GetLogger(),
secrets.GetSecretKeyService(),
}
func GetRestorePostgresqlBackupUsecase() *RestorePostgresqlBackupUsecase {

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -15,11 +16,14 @@ import (
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
util_encryption "postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/tools"
@@ -27,16 +31,20 @@ import (
)
type RestorePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyService *encryption_secrets.SecretKeyService
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
originalDB *databases.Database,
restoringToDB *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
if backup.Database.Type != databases.DatabaseTypePostgres {
if originalDB.Type != databases.DatabaseTypePostgres {
return errors.New("database type not supported")
}
@@ -48,7 +56,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
backup.ID,
)
pg := restore.Postgresql
pg := restoringToDB.Postgresql
if pg == nil {
return fmt.Errorf("postgresql configuration is required for restore")
}
@@ -72,10 +80,12 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
"--verbose", // Add verbose output to help with debugging
"--clean", // Clean (drop) database objects before recreating them
"--if-exists", // Use IF EXISTS when dropping objects
"--no-owner",
"--no-owner", // Skip restoring ownership
"--no-acl", // Skip restoring access privileges (GRANT/REVOKE commands)
}
return uc.restoreFromStorage(
originalDB,
tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
@@ -87,17 +97,20 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
backup,
storage,
pg,
isExcludeExtensions,
)
}
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
database *databases.Database,
pgBin string,
args []string,
password string,
backup *backups.Backup,
storage *storages.Storage,
pgConfig *pgtypes.PostgresqlDatabase,
isExcludeExtensions bool,
) error {
uc.logger.Info(
"Restoring PostgreSQL backup from storage via temporary file",
@@ -105,6 +118,8 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
pgBin,
"args",
args,
"isExcludeExtensions",
isExcludeExtensions,
)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
@@ -161,10 +176,30 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
}
defer cleanupFunc()
// If excluding extensions, generate filtered TOC list and use it
if isExcludeExtensions {
tocListFile, err := uc.generateFilteredTocList(
ctx,
pgBin,
tempBackupFile,
pgpassFile,
pgConfig,
)
if err != nil {
return fmt.Errorf("failed to generate filtered TOC list: %w", err)
}
defer func() {
_ = os.Remove(tocListFile)
}()
// Add -L flag to use the filtered list
args = append(args, "-L", tocListFile)
}
// Add the temporary backup file as the last argument to pg_restore
args = append(args, tempBackupFile)
return uc.executePgRestore(ctx, pgBin, args, pgpassFile, pgConfig, backup)
return uc.executePgRestore(ctx, database, pgBin, args, pgpassFile, pgConfig)
}
// downloadBackupToTempFile downloads backup data from storage to a temporary file
@@ -199,18 +234,67 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
backup.ID,
"tempFile",
tempBackupFile,
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
backupReader, err := storage.GetFile(backup.ID)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := backupReader.Close(); err != nil {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
// Create a reader that handles decryption if needed
var backupReader io.Reader = rawReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
// Validate encryption metadata
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
cleanupFunc()
return "", nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get master key for decryption: %w", err)
}
// Decode salt and IV from base64
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption IV: %w", err)
}
// Create decryption reader
decryptReader, err := encryption.NewDecryptionReader(
rawReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to create decryption reader: %w", err)
}
backupReader = decryptReader
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
}
// Create temporary backup file
tempFile, err := os.Create(tempBackupFile)
if err != nil {
@@ -240,11 +324,11 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
// executePgRestore executes the pg_restore command with proper environment setup
func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
ctx context.Context,
database *databases.Database,
pgBin string,
args []string,
pgpassFile string,
pgConfig *pgtypes.PostgresqlDatabase,
backup *backups.Backup,
) error {
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL restore command", "command", cmd.String())
@@ -293,7 +377,7 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
return fmt.Errorf("restore cancelled due to shutdown")
}
return uc.handlePgRestoreError(waitErr, stderrOutput, pgBin, args, backup, pgConfig)
return uc.handlePgRestoreError(database, waitErr, stderrOutput, pgBin, args, pgConfig)
}
return nil
@@ -319,7 +403,6 @@ func (uc *RestorePostgresqlBackupUsecase) setupPgRestoreEnvironment(
// Add encoding-related environment variables
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
shouldRequireSSL := pgConfig.IsHttps
@@ -341,11 +424,11 @@ func (uc *RestorePostgresqlBackupUsecase) setupPgRestoreEnvironment(
// handlePgRestoreError processes and formats pg_restore errors
func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
database *databases.Database,
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
backup *backups.Backup,
pgConfig *pgtypes.PostgresqlDatabase,
) error {
// Enhanced error handling for PostgreSQL connection and restore issues
@@ -416,8 +499,8 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
)
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
backupDbName := "unknown"
if backup.Database != nil && backup.Database.Postgresql != nil && backup.Database.Postgresql.Database != nil {
backupDbName = *backup.Database.Postgresql.Database
if database.Postgresql != nil && database.Postgresql.Database != nil {
backupDbName = *database.Postgresql.Database
}
targetDbName := "unknown"
@@ -444,7 +527,7 @@ func (uc *RestorePostgresqlBackupUsecase) copyWithShutdownCheck(
dst io.Writer,
src io.Reader,
) (int64, error) {
buf := make([]byte, 32*1024) // 32KB buffer
buf := make([]byte, 16*1024*1024) // 16MB buffer
var totalBytesWritten int64
for {
@@ -496,6 +579,75 @@ func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
}
// generateFilteredTocList generates a pg_restore TOC list file with extensions filtered out.
// This is used when isExcludeExtensions is true to skip CREATE EXTENSION statements.
func (uc *RestorePostgresqlBackupUsecase) generateFilteredTocList(
ctx context.Context,
pgBin string,
backupFile string,
pgpassFile string,
pgConfig *pgtypes.PostgresqlDatabase,
) (string, error) {
uc.logger.Info("Generating filtered TOC list to exclude extensions", "backupFile", backupFile)
// Run pg_restore -l to get the TOC list
listCmd := exec.CommandContext(ctx, pgBin, "-l", backupFile)
uc.setupPgRestoreEnvironment(listCmd, pgpassFile, pgConfig)
tocOutput, err := listCmd.Output()
if err != nil {
return "", fmt.Errorf("failed to generate TOC list: %w", err)
}
// Filter out EXTENSION-related lines (both CREATE EXTENSION and COMMENT ON EXTENSION)
var filteredLines []string
for line := range strings.SplitSeq(string(tocOutput), "\n") {
trimmedLine := strings.TrimSpace(line)
if trimmedLine == "" {
continue
}
upperLine := strings.ToUpper(trimmedLine)
// Skip lines that contain " EXTENSION " - this catches both:
// - CREATE EXTENSION entries: "3420; 0 0 EXTENSION - uuid-ossp"
// - COMMENT ON EXTENSION entries: "3462; 0 0 COMMENT - EXTENSION "uuid-ossp""
if strings.Contains(upperLine, " EXTENSION ") {
uc.logger.Info("Excluding extension-related entry from restore", "tocLine", trimmedLine)
continue
}
filteredLines = append(filteredLines, line)
}
// Write filtered TOC to temporary file
tocFile, err := os.CreateTemp("", "pg_restore_toc_*.list")
if err != nil {
return "", fmt.Errorf("failed to create TOC list file: %w", err)
}
tocFilePath := tocFile.Name()
filteredContent := strings.Join(filteredLines, "\n")
if _, err := tocFile.WriteString(filteredContent); err != nil {
_ = tocFile.Close()
_ = os.Remove(tocFilePath)
return "", fmt.Errorf("failed to write TOC list file: %w", err)
}
if err := tocFile.Close(); err != nil {
_ = os.Remove(tocFilePath)
return "", fmt.Errorf("failed to close TOC list file: %w", err)
}
uc.logger.Info("Generated filtered TOC list file",
"tocFile", tocFilePath,
"originalLines", len(strings.Split(string(tocOutput), "\n")),
"filteredLines", len(filteredLines),
)
return tocFilePath, nil
}
// createTempPgpassFile creates a temporary .pgpass file with the given password
func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
@@ -505,11 +657,15 @@ func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
return "", nil
}
escapedHost := tools.EscapePgpassField(pgConfig.Host)
escapedUsername := tools.EscapePgpassField(pgConfig.Username)
escapedPassword := tools.EscapePgpassField(password)
pgpassContent := fmt.Sprintf("%s:%d:*:%s:%s",
pgConfig.Host,
escapedHost,
pgConfig.Port,
pgConfig.Username,
password,
escapedUsername,
escapedPassword,
)
tempDir, err := os.MkdirTemp("", "pgpass")

View File

@@ -17,15 +17,21 @@ type RestoreBackupUsecase struct {
func (uc *RestoreBackupUsecase) Execute(
backupConfig *backups_config.BackupConfig,
restore models.Restore,
originalDB *databases.Database,
restoringToDB *databases.Database,
backup *backups.Backup,
storage *storages.Storage,
isExcludeExtensions bool,
) error {
if restore.Backup.Database.Type == databases.DatabaseTypePostgres {
if originalDB.Type == databases.DatabaseTypePostgres {
return uc.restorePostgresqlBackupUsecase.Execute(
originalDB,
restoringToDB,
backupConfig,
restore,
backup,
storage,
isExcludeExtensions,
)
}

View File

@@ -3,17 +3,25 @@ package storages
import (
"fmt"
"net/http"
"strings"
"testing"
audit_logs "postgresus-backend/internal/features/audit_logs"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
@@ -438,6 +446,535 @@ func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *test
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
"S3AccessKey should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
"S3SecretKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
assert.NoError(t, err)
assert.Equal(t, "original-access-key", accessKey)
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
assert.NoError(t, err)
assert.Equal(t, "original-secret-key", secretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
{
name: "NAS Storage",
storageType: StorageTypeNAS,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Test NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas.example.com",
Port: 445,
Share: "backups",
Username: "testuser",
Password: "original-password",
UseSSL: false,
Domain: "WORKGROUP",
Path: "/test",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Updated NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas2.example.com",
Port: 445,
Share: "backups2",
Username: "testuser2",
Password: "",
UseSSL: true,
Domain: "WORKGROUP2",
Path: "/test2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.NASStorage.Password)
},
},
{
name: "Azure Blob Storage (Connection String)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "original-connection-string",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
"ConnectionString should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
connectionString, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.ConnectionString,
)
assert.NoError(t, err)
assert.Equal(t, "original-connection-string", connectionString)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Azure Blob Storage (Account Key)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "testaccount",
AccountKey: "original-account-key",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "updatedaccount",
AccountKey: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
"AccountKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accountKey, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.AccountKey,
)
assert.NoError(t, err)
assert.Equal(t, "original-account-key", accountKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Google Drive Storage",
storageType: StorageTypeGoogleDrive,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Test Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Updated Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "updated-client-id",
ClientSecret: "",
TokenJSON: "",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
"ClientSecret should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
"TokenJSON should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
clientSecret, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.ClientSecret,
)
assert.NoError(t, err)
assert.Equal(t, "original-client-secret", clientSecret)
tokenJSON, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.TokenJSON,
)
assert.NoError(t, err)
assert.Equal(
t,
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
tokenJSON,
)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
},
},
{
name: "FTP Storage",
storageType: StorageTypeFTP,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeFTP,
Name: "Test FTP Storage",
FTPStorage: &ftp_storage.FTPStorage{
Host: "ftp.example.com",
Port: 21,
Username: "testuser",
Password: "original-password",
UseSSL: false,
Path: "/backups",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeFTP,
Name: "Updated FTP Storage",
FTPStorage: &ftp_storage.FTPStorage{
Host: "ftp2.example.com",
Port: 2121,
Username: "testuser2",
Password: "",
UseSSL: true,
Path: "/backups2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.FTPStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.FTPStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.FTPStorage.Password)
},
},
{
name: "SFTP Storage",
storageType: StorageTypeSFTP,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeSFTP,
Name: "Test SFTP Storage",
SFTPStorage: &sftp_storage.SFTPStorage{
Host: "sftp.example.com",
Port: 22,
Username: "testuser",
Password: "original-password",
PrivateKey: "original-private-key",
SkipHostKeyVerify: false,
Path: "/backups",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeSFTP,
Name: "Updated SFTP Storage",
SFTPStorage: &sftp_storage.SFTPStorage{
Host: "sftp2.example.com",
Port: 2222,
Username: "testuser2",
Password: "",
PrivateKey: "",
SkipHostKeyVerify: true,
Path: "/backups2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.SFTPStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.SFTPStorage.PrivateKey, "enc:"),
"PrivateKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.SFTPStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
privateKey, err := encryptor.Decrypt(storage.ID, storage.SFTPStorage.PrivateKey)
assert.NoError(t, err)
assert.Equal(t, "original-private-key", privateKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.SFTPStorage.Password)
assert.Equal(t, "", storage.SFTPStorage.PrivateKey)
},
},
{
name: "Rclone Storage",
storageType: StorageTypeRclone,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeRclone,
Name: "Test Rclone Storage",
RcloneStorage: &rclone_storage.RcloneStorage{
ConfigContent: "[myremote]\ntype = s3\nprovider = AWS\naccess_key_id = test\nsecret_access_key = secret\n",
RemotePath: "/backups",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeRclone,
Name: "Updated Rclone Storage",
RcloneStorage: &rclone_storage.RcloneStorage{
ConfigContent: "",
RemotePath: "/backups2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.RcloneStorage.ConfigContent, "enc:"),
"ConfigContent should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
configContent, err := encryptor.Decrypt(
storage.ID,
storage.RcloneStorage.ConfigContent,
)
assert.NoError(t, err)
assert.Equal(
t,
"[myremote]\ntype = s3\nprovider = AWS\naccess_key_id = test\nsecret_access_key = secret\n",
configContent,
)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.RcloneStorage.ConfigContent)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Verify sensitive data is encrypted in repository after creation
repository := &StorageRepository{}
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
tc.verifySensitiveData(t, storageFromDBAfterCreate)
// Phase 3: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
@@ -485,158 +1022,3 @@ func deleteStorage(
http.StatusOK,
)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "original-access-key", storage.S3Storage.S3AccessKey)
assert.Equal(t, "original-secret-key", storage.S3Storage.S3SecretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &StorageRepository{}
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}

View File

@@ -3,6 +3,7 @@ package storages
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
)
var storageRepository = &StorageRepository{}
@@ -10,6 +11,7 @@ var storageService = &StorageService{
storageRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var storageController = &StorageController{
storageService,

View File

@@ -7,4 +7,8 @@ const (
StorageTypeS3 StorageType = "S3"
StorageTypeGoogleDrive StorageType = "GOOGLE_DRIVE"
StorageTypeNAS StorageType = "NAS"
StorageTypeAzureBlob StorageType = "AZURE_BLOB"
StorageTypeFTP StorageType = "FTP"
StorageTypeSFTP StorageType = "SFTP"
StorageTypeRclone StorageType = "RCLONE"
)

View File

@@ -1,22 +1,32 @@
package storages
import (
"context"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
type StorageFileSaver interface {
SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error
SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error
GetFile(fileID uuid.UUID) (io.ReadCloser, error)
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
DeleteFile(fileID uuid.UUID) error
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
TestConnection() error
TestConnection(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -1,13 +1,19 @@
package storages
import (
"context"
"errors"
"io"
"log/slog"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -24,10 +30,20 @@ type Storage struct {
S3Storage *s3_storage.S3Storage `json:"s3Storage" gorm:"foreignKey:StorageID"`
GoogleDriveStorage *google_drive_storage.GoogleDriveStorage `json:"googleDriveStorage" gorm:"foreignKey:StorageID"`
NASStorage *nas_storage.NASStorage `json:"nasStorage" gorm:"foreignKey:StorageID"`
AzureBlobStorage *azure_blob_storage.AzureBlobStorage `json:"azureBlobStorage" gorm:"foreignKey:StorageID"`
FTPStorage *ftp_storage.FTPStorage `json:"ftpStorage" gorm:"foreignKey:StorageID"`
SFTPStorage *sftp_storage.SFTPStorage `json:"sftpStorage" gorm:"foreignKey:StorageID"`
RcloneStorage *rclone_storage.RcloneStorage `json:"rcloneStorage" gorm:"foreignKey:StorageID"`
}
func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
err := s.getSpecificStorage().SaveFile(logger, fileID, file)
func (s *Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileID, file)
if err != nil {
lastSaveError := err.Error()
s.LastSaveError = &lastSaveError
@@ -39,15 +55,18 @@ func (s *Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader
return nil
}
func (s *Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(fileID)
func (s *Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
return s.getSpecificStorage().GetFile(encryptor, fileID)
}
func (s *Storage) DeleteFile(fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(fileID)
func (s *Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
return s.getSpecificStorage().DeleteFile(encryptor, fileID)
}
func (s *Storage) Validate() error {
func (s *Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.Type == "" {
return errors.New("storage type is required")
}
@@ -56,17 +75,21 @@ func (s *Storage) Validate() error {
return errors.New("storage name is required")
}
return s.getSpecificStorage().Validate()
return s.getSpecificStorage().Validate(encryptor)
}
func (s *Storage) TestConnection() error {
return s.getSpecificStorage().TestConnection()
func (s *Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().TestConnection(encryptor)
}
func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return s.getSpecificStorage().EncryptSensitiveData(encryptor)
}
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type
@@ -88,6 +111,22 @@ func (s *Storage) Update(incoming *Storage) {
if s.NASStorage != nil && incoming.NASStorage != nil {
s.NASStorage.Update(incoming.NASStorage)
}
case StorageTypeAzureBlob:
if s.AzureBlobStorage != nil && incoming.AzureBlobStorage != nil {
s.AzureBlobStorage.Update(incoming.AzureBlobStorage)
}
case StorageTypeFTP:
if s.FTPStorage != nil && incoming.FTPStorage != nil {
s.FTPStorage.Update(incoming.FTPStorage)
}
case StorageTypeSFTP:
if s.SFTPStorage != nil && incoming.SFTPStorage != nil {
s.SFTPStorage.Update(incoming.SFTPStorage)
}
case StorageTypeRclone:
if s.RcloneStorage != nil && incoming.RcloneStorage != nil {
s.RcloneStorage.Update(incoming.RcloneStorage)
}
}
}
@@ -101,6 +140,14 @@ func (s *Storage) getSpecificStorage() StorageFileSaver {
return s.GoogleDriveStorage
case StorageTypeNAS:
return s.NASStorage
case StorageTypeAzureBlob:
return s.AzureBlobStorage
case StorageTypeFTP:
return s.FTPStorage
case StorageTypeSFTP:
return s.SFTPStorage
case StorageTypeRclone:
return s.RcloneStorage
default:
panic("invalid storage type: " + string(s.Type))
}

View File

@@ -8,15 +8,21 @@ import (
"os"
"path/filepath"
"postgresus-backend/internal/config"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"strconv"
"testing"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/google/uuid"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
@@ -32,6 +38,15 @@ type S3Container struct {
region string
}
type AzuriteContainer struct {
endpoint string
accountName string
accountKey string
containerNameKey string
containerNameStr string
connectionString string
}
func Test_Storage_BasicOperations(t *testing.T) {
ctx := context.Background()
@@ -41,6 +56,10 @@ func Test_Storage_BasicOperations(t *testing.T) {
s3Container, err := setupS3Container(ctx)
require.NoError(t, err, "Failed to setup S3 container")
// Setup Azurite connection
azuriteContainer, err := setupAzuriteContainer(ctx)
require.NoError(t, err, "Failed to setup Azurite container")
// Setup test file
testFilePath, err := setupTestFile()
require.NoError(t, err, "Failed to setup test file")
@@ -54,6 +73,22 @@ func Test_Storage_BasicOperations(t *testing.T) {
}
}
// Setup FTP port
ftpPort := 21
if portStr := config.GetEnv().TestFTPPort; portStr != "" {
if port, err := strconv.Atoi(portStr); err == nil {
ftpPort = port
}
}
// Setup SFTP port
sftpPort := 22
if portStr := config.GetEnv().TestSFTPPort; portStr != "" {
if port, err := strconv.Atoi(portStr); err == nil {
sftpPort = port
}
}
// Run tests
testCases := []struct {
name string
@@ -88,6 +123,64 @@ func Test_Storage_BasicOperations(t *testing.T) {
Path: "test-files",
},
},
{
name: "AzureBlobStorage_AccountKey",
storage: &azure_blob_storage.AzureBlobStorage{
StorageID: uuid.New(),
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: azuriteContainer.accountName,
AccountKey: azuriteContainer.accountKey,
ContainerName: azuriteContainer.containerNameKey,
Endpoint: azuriteContainer.endpoint,
},
},
{
name: "AzureBlobStorage_ConnectionString",
storage: &azure_blob_storage.AzureBlobStorage{
StorageID: uuid.New(),
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: azuriteContainer.connectionString,
ContainerName: azuriteContainer.containerNameStr,
},
},
{
name: "FTPStorage",
storage: &ftp_storage.FTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Port: ftpPort,
Username: "testuser",
Password: "testpassword",
UseSSL: false,
Path: "test-files",
},
},
{
name: "SFTPStorage",
storage: &sftp_storage.SFTPStorage{
StorageID: uuid.New(),
Host: "localhost",
Port: sftpPort,
Username: "testuser",
Password: "testpassword",
SkipHostKeyVerify: true,
Path: "upload",
},
},
{
name: "RcloneStorage",
storage: &rclone_storage.RcloneStorage{
StorageID: uuid.New(),
ConfigContent: fmt.Sprintf(`[minio]
type = s3
provider = Other
access_key_id = %s
secret_access_key = %s
endpoint = http://%s
acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoint),
RemotePath: s3Container.bucketName,
},
},
}
// Add Google Drive storage test only if environment variables are available
@@ -112,13 +205,15 @@ func Test_Storage_BasicOperations(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
encryptor := encryption.GetFieldEncryptor()
t.Run("Test_TestConnection_ConnectionSucceeds", func(t *testing.T) {
err := tc.storage.TestConnection()
err := tc.storage.TestConnection(encryptor)
assert.NoError(t, err, "TestConnection should succeed")
})
t.Run("Test_TestValidation_ValidationSucceeds", func(t *testing.T) {
err := tc.storage.Validate()
err := tc.storage.Validate(encryptor)
assert.NoError(t, err, "Validate should succeed")
})
@@ -128,10 +223,16 @@ func Test_Storage_BasicOperations(t *testing.T) {
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.NoError(t, err, "GetFile should succeed")
defer file.Close()
@@ -145,13 +246,19 @@ func Test_Storage_BasicOperations(t *testing.T) {
require.NoError(t, err, "Should be able to read test file")
fileID := uuid.New()
err = tc.storage.SaveFile(logger.GetLogger(), fileID, bytes.NewReader(fileData))
err = tc.storage.SaveFile(
context.Background(),
encryptor,
logger.GetLogger(),
fileID,
bytes.NewReader(fileData),
)
require.NoError(t, err, "SaveFile should succeed")
err = tc.storage.DeleteFile(fileID)
err = tc.storage.DeleteFile(encryptor, fileID)
assert.NoError(t, err, "DeleteFile should succeed")
file, err := tc.storage.GetFile(fileID)
file, err := tc.storage.GetFile(encryptor, fileID)
assert.Error(t, err, "GetFile should fail for non-existent file")
if file != nil {
file.Close()
@@ -161,7 +268,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
t.Run("Test_TestDeleteNonExistentFile_DoesNotError", func(t *testing.T) {
// Try to delete a non-existent file
nonExistentID := uuid.New()
err := tc.storage.DeleteFile(nonExistentID)
err := tc.storage.DeleteFile(encryptor, nonExistentID)
assert.NoError(t, err, "DeleteFile should not error for non-existent file")
})
})
@@ -190,7 +297,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
secretKey := "testpassword"
bucketName := "test-bucket"
region := "us-east-1"
endpoint := fmt.Sprintf("localhost:%s", env.TestMinioPort)
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
// Create MinIO client and ensure bucket exists
minioClient, err := minio.New(endpoint, &minio.Options{
@@ -230,8 +337,59 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
}, nil
}
func setupAzuriteContainer(ctx context.Context) (*AzuriteContainer, error) {
env := config.GetEnv()
accountName := "devstoreaccount1"
// this is real testing key for azurite, it's not a real key
accountKey := "Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw=="
serviceURL := fmt.Sprintf("http://127.0.0.1:%s/%s", env.TestAzuriteBlobPort, accountName)
containerNameKey := "test-container-key"
containerNameStr := "test-container-connstr"
// Build explicit connection string for Azurite
connectionString := fmt.Sprintf(
"DefaultEndpointsProtocol=http;AccountName=%s;AccountKey=%s;BlobEndpoint=http://127.0.0.1:%s/%s",
accountName,
accountKey,
env.TestAzuriteBlobPort,
accountName,
)
// Create client using connection string to set up containers
client, err := azblob.NewClientFromConnectionString(connectionString, nil)
if err != nil {
return nil, fmt.Errorf("failed to create azblob client: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
// Create container for account key auth
_, err = client.CreateContainer(ctx, containerNameKey, nil)
if err != nil {
// Container might already exist, that's okay
}
// Create container for connection string auth
_, err = client.CreateContainer(ctx, containerNameStr, nil)
if err != nil {
// Container might already exist, that's okay
}
return &AzuriteContainer{
endpoint: serviceURL,
accountName: accountName,
accountKey: accountKey,
containerNameKey: containerNameKey,
containerNameStr: containerNameStr,
connectionString: connectionString,
}, nil
}
func validateEnvVariables(t *testing.T) {
env := config.GetEnv()
assert.NotEmpty(t, env.TestMinioPort, "TEST_MINIO_PORT is empty")
assert.NotEmpty(t, env.TestAzuriteBlobPort, "TEST_AZURITE_BLOB_PORT is empty")
assert.NotEmpty(t, env.TestNASPort, "TEST_NAS_PORT is empty")
}

View File

@@ -0,0 +1,414 @@
package azure_blob_storage
import (
"bytes"
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
"github.com/google/uuid"
)
const (
azureConnectTimeout = 30 * time.Second
azureResponseTimeout = 30 * time.Second
azureIdleConnTimeout = 90 * time.Second
azureTLSHandshakeTimeout = 30 * time.Second
// Chunk size for block blob uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for Azure to confirm receipt.
azureChunkSize = 16 * 1024 * 1024
)
type readSeekCloser struct {
*bytes.Reader
}
func (r *readSeekCloser) Close() error {
return nil
}
type AuthMethod string
const (
AuthMethodConnectionString AuthMethod = "CONNECTION_STRING"
AuthMethodAccountKey AuthMethod = "ACCOUNT_KEY"
)
type AzureBlobStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
AuthMethod AuthMethod `json:"authMethod" gorm:"not null;type:text;column:auth_method"`
ConnectionString string `json:"connectionString" gorm:"type:text;column:connection_string"`
AccountName string `json:"accountName" gorm:"type:text;column:account_name"`
AccountKey string `json:"accountKey" gorm:"type:text;column:account_key"`
ContainerName string `json:"containerName" gorm:"not null;type:text;column:container_name"`
Endpoint string `json:"endpoint" gorm:"type:text;column:endpoint"`
Prefix string `json:"prefix" gorm:"type:text;column:prefix"`
}
func (s *AzureBlobStorage) TableName() string {
return "azure_blob_storages"
}
func (s *AzureBlobStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled before start: %w", ctx.Err())
default:
}
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
blockBlobClient := client.ServiceClient().
NewContainerClient(s.ContainerName).
NewBlockBlobClient(blobName)
var blockIDs []string
blockNumber := 0
buf := make([]byte, azureChunkSize)
for {
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled: %w", ctx.Err())
default:
}
n, readErr := io.ReadFull(file, buf)
if n == 0 && readErr == io.EOF {
break
}
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
return fmt.Errorf("read error: %w", readErr)
}
blockID := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%06d", blockNumber)))
_, err := blockBlobClient.StageBlock(
ctx,
blockID,
&readSeekCloser{bytes.NewReader(buf[:n])},
nil,
)
if err != nil {
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled: %w", ctx.Err())
default:
return fmt.Errorf("failed to stage block %d: %w", blockNumber, err)
}
}
blockIDs = append(blockIDs, blockID)
blockNumber++
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
break
}
}
if len(blockIDs) == 0 {
_, err = client.UploadStream(
ctx,
s.ContainerName,
blobName,
bytes.NewReader([]byte{}),
nil,
)
if err != nil {
return fmt.Errorf("failed to upload empty blob: %w", err)
}
return nil
}
_, err = blockBlobClient.CommitBlockList(ctx, blockIDs, &blockblob.CommitBlockListOptions{})
if err != nil {
return fmt.Errorf("failed to commit block list: %w", err)
}
return nil
}
func (s *AzureBlobStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
blobName := s.buildBlobName(fileID.String())
response, err := client.DownloadStream(
context.TODO(),
s.ContainerName,
blobName,
nil,
)
if err != nil {
return nil, fmt.Errorf("failed to download blob from Azure: %w", err)
}
return response.Body, nil
}
func (s *AzureBlobStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
blobName := s.buildBlobName(fileID.String())
_, err = client.DeleteBlob(
context.TODO(),
s.ContainerName,
blobName,
nil,
)
if err != nil {
var respErr *azcore.ResponseError
if errors.As(err, &respErr) && respErr.StatusCode == 404 {
return nil
}
return fmt.Errorf("failed to delete blob from Azure: %w", err)
}
return nil
}
func (s *AzureBlobStorage) Validate(encryptor encryption.FieldEncryptor) error {
if s.ContainerName == "" {
return errors.New("container name is required")
}
switch s.AuthMethod {
case AuthMethodConnectionString:
if s.ConnectionString == "" {
return errors.New(
"connection string is required when using CONNECTION_STRING auth method",
)
}
case AuthMethodAccountKey:
if s.AccountName == "" {
return errors.New("account name is required when using ACCOUNT_KEY auth method")
}
if s.AccountKey == "" {
return errors.New("account key is required when using ACCOUNT_KEY auth method")
}
default:
return fmt.Errorf("invalid auth method: %s", s.AuthMethod)
}
return nil
}
func (s *AzureBlobStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
containerClient := client.ServiceClient().NewContainerClient(s.ContainerName)
_, err = containerClient.GetProperties(ctx, nil)
if err != nil {
var respErr *azcore.ResponseError
if errors.As(err, &respErr) {
if respErr.StatusCode == 404 {
return fmt.Errorf("container '%s' does not exist", s.ContainerName)
}
}
if errors.Is(err, context.DeadlineExceeded) {
return errors.New("failed to connect to Azure Blob Storage. Please check params")
}
return fmt.Errorf("failed to connect to Azure Blob Storage: %w", err)
}
testBlobName := s.buildBlobName(uuid.New().String() + "-test")
testData := []byte("test connection")
_, err = client.UploadStream(
ctx,
s.ContainerName,
testBlobName,
bytes.NewReader(testData),
nil,
)
if err != nil {
return fmt.Errorf("failed to upload test blob to Azure: %w", err)
}
_, err = client.DeleteBlob(
ctx,
s.ContainerName,
testBlobName,
nil,
)
if err != nil {
return fmt.Errorf("failed to delete test blob from Azure: %w", err)
}
return nil
}
func (s *AzureBlobStorage) HideSensitiveData() {
s.ConnectionString = ""
s.AccountKey = ""
}
func (s *AzureBlobStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ConnectionString != "" {
s.ConnectionString, err = encryptor.Encrypt(s.StorageID, s.ConnectionString)
if err != nil {
return fmt.Errorf("failed to encrypt Azure connection string: %w", err)
}
}
if s.AccountKey != "" {
s.AccountKey, err = encryptor.Encrypt(s.StorageID, s.AccountKey)
if err != nil {
return fmt.Errorf("failed to encrypt Azure account key: %w", err)
}
}
return nil
}
func (s *AzureBlobStorage) Update(incoming *AzureBlobStorage) {
s.AuthMethod = incoming.AuthMethod
s.ContainerName = incoming.ContainerName
s.Endpoint = incoming.Endpoint
if incoming.ConnectionString != "" {
s.ConnectionString = incoming.ConnectionString
}
if incoming.AccountName != "" {
s.AccountName = incoming.AccountName
}
if incoming.AccountKey != "" {
s.AccountKey = incoming.AccountKey
}
}
func (s *AzureBlobStorage) buildBlobName(fileName string) string {
if s.Prefix == "" {
return fileName
}
prefix := s.Prefix
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
}
return prefix + fileName
}
func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azblob.Client, error) {
var client *azblob.Client
var err error
clientOptions := s.buildClientOptions()
switch s.AuthMethod {
case AuthMethodConnectionString:
connectionString, decryptErr := encryptor.Decrypt(s.StorageID, s.ConnectionString)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure connection string: %w", decryptErr)
}
client, err = azblob.NewClientFromConnectionString(connectionString, clientOptions)
if err != nil {
return nil, fmt.Errorf(
"failed to create Azure Blob client from connection string: %w",
err,
)
}
case AuthMethodAccountKey:
accountKey, decryptErr := encryptor.Decrypt(s.StorageID, s.AccountKey)
if decryptErr != nil {
return nil, fmt.Errorf("failed to decrypt Azure account key: %w", decryptErr)
}
accountURL := s.buildAccountURL()
credential, credErr := azblob.NewSharedKeyCredential(s.AccountName, accountKey)
if credErr != nil {
return nil, fmt.Errorf("failed to create Azure shared key credential: %w", credErr)
}
client, err = azblob.NewClientWithSharedKeyCredential(accountURL, credential, clientOptions)
if err != nil {
return nil, fmt.Errorf("failed to create Azure Blob client with shared key: %w", err)
}
default:
return nil, fmt.Errorf("unsupported auth method: %s", s.AuthMethod)
}
return client, nil
}
func (s *AzureBlobStorage) buildClientOptions() *azblob.ClientOptions {
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: azureConnectTimeout,
}).DialContext,
TLSHandshakeTimeout: azureTLSHandshakeTimeout,
ResponseHeaderTimeout: azureResponseTimeout,
IdleConnTimeout: azureIdleConnTimeout,
}
return &azblob.ClientOptions{
ClientOptions: azcore.ClientOptions{
Transport: &http.Client{Transport: transport},
Retry: policy.RetryOptions{
MaxRetries: 0,
},
},
}
}
func (s *AzureBlobStorage) buildAccountURL() string {
if s.Endpoint != "" {
endpoint := s.Endpoint
if !strings.HasPrefix(endpoint, "http://") && !strings.HasPrefix(endpoint, "https://") {
endpoint = "https://" + endpoint
}
return endpoint
}
return fmt.Sprintf("https://%s.blob.core.windows.net/", s.AccountName)
}

View File

@@ -0,0 +1,368 @@
package ftp_storage
import (
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
"github.com/google/uuid"
"github.com/jlaffaye/ftp"
)
const (
ftpConnectTimeout = 30 * time.Second
ftpTestConnectTimeout = 10 * time.Second
ftpChunkSize = 16 * 1024 * 1024
)
type FTPStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
Host string `json:"host" gorm:"not null;type:text;column:host"`
Port int `json:"port" gorm:"not null;default:21;column:port"`
Username string `json:"username" gorm:"not null;type:text;column:username"`
Password string `json:"password" gorm:"not null;type:text;column:password"`
Path string `json:"path" gorm:"type:text;column:path"`
UseSSL bool `json:"useSsl" gorm:"not null;default:false;column:use_ssl"`
SkipTLSVerify bool `json:"skipTlsVerify" gorm:"not null;default:false;column:skip_tls_verify"`
}
func (f *FTPStorage) TableName() string {
return "ftp_storages"
}
func (f *FTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.Info("Starting to save file to FTP storage", "fileId", fileID.String(), "host", f.Host)
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to FTP", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to connect to FTP: %w", err)
}
defer func() {
if quitErr := conn.Quit(); quitErr != nil {
logger.Error(
"Failed to close FTP connection",
"fileId",
fileID.String(),
"error",
quitErr,
)
}
}()
if f.Path != "" {
if err := f.ensureDirectory(conn, f.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"path",
f.Path,
"error",
err,
)
return fmt.Errorf("failed to ensure directory: %w", err)
}
}
filePath := f.getFilePath(fileID.String())
logger.Debug("Uploading file to FTP", "fileId", fileID.String(), "filePath", filePath)
ctxReader := &contextReader{ctx: ctx, reader: file}
err = conn.Stor(filePath, ctxReader)
if err != nil {
select {
case <-ctx.Done():
logger.Info("FTP upload cancelled", "fileId", fileID.String())
return ctx.Err()
default:
logger.Error("Failed to upload file to FTP", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to upload file to FTP: %w", err)
}
}
logger.Info(
"Successfully saved file to FTP storage",
"fileId",
fileID.String(),
"filePath",
filePath,
)
return nil
}
func (f *FTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to FTP: %w", err)
}
filePath := f.getFilePath(fileID.String())
resp, err := conn.Retr(filePath)
if err != nil {
_ = conn.Quit()
return nil, fmt.Errorf("failed to retrieve file from FTP: %w", err)
}
return &ftpFileReader{
response: resp,
conn: conn,
}, nil
}
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
conn, err := f.connect(encryptor, ftpConnectTimeout)
if err != nil {
return fmt.Errorf("failed to connect to FTP: %w", err)
}
defer func() {
_ = conn.Quit()
}()
filePath := f.getFilePath(fileID.String())
_, err = conn.FileSize(filePath)
if err != nil {
return nil
}
err = conn.Delete(filePath)
if err != nil {
return fmt.Errorf("failed to delete file from FTP: %w", err)
}
return nil
}
func (f *FTPStorage) Validate(encryptor encryption.FieldEncryptor) error {
if f.Host == "" {
return errors.New("FTP host is required")
}
if f.Username == "" {
return errors.New("FTP username is required")
}
if f.Password == "" {
return errors.New("FTP password is required")
}
if f.Port <= 0 || f.Port > 65535 {
return errors.New("FTP port must be between 1 and 65535")
}
return nil
}
func (f *FTPStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
ctx, cancel := context.WithTimeout(context.Background(), ftpTestConnectTimeout)
defer cancel()
conn, err := f.connectWithContext(ctx, encryptor, ftpTestConnectTimeout)
if err != nil {
return fmt.Errorf("failed to connect to FTP: %w", err)
}
defer func() {
_ = conn.Quit()
}()
if f.Path != "" {
if err := f.ensureDirectory(conn, f.Path); err != nil {
return fmt.Errorf("failed to access or create path '%s': %w", f.Path, err)
}
}
return nil
}
func (f *FTPStorage) HideSensitiveData() {
f.Password = ""
}
func (f *FTPStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if f.Password != "" {
encrypted, err := encryptor.Encrypt(f.StorageID, f.Password)
if err != nil {
return fmt.Errorf("failed to encrypt FTP password: %w", err)
}
f.Password = encrypted
}
return nil
}
func (f *FTPStorage) Update(incoming *FTPStorage) {
f.Host = incoming.Host
f.Port = incoming.Port
f.Username = incoming.Username
f.UseSSL = incoming.UseSSL
f.SkipTLSVerify = incoming.SkipTLSVerify
f.Path = incoming.Path
if incoming.Password != "" {
f.Password = incoming.Password
}
}
func (f *FTPStorage) connect(
encryptor encryption.FieldEncryptor,
timeout time.Duration,
) (*ftp.ServerConn, error) {
return f.connectWithContext(context.Background(), encryptor, timeout)
}
func (f *FTPStorage) connectWithContext(
ctx context.Context,
encryptor encryption.FieldEncryptor,
timeout time.Duration,
) (*ftp.ServerConn, error) {
password, err := encryptor.Decrypt(f.StorageID, f.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt FTP password: %w", err)
}
address := fmt.Sprintf("%s:%d", f.Host, f.Port)
dialCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
var conn *ftp.ServerConn
if f.UseSSL {
tlsConfig := &tls.Config{
ServerName: f.Host,
InsecureSkipVerify: f.SkipTLSVerify,
}
conn, err = ftp.Dial(address,
ftp.DialWithContext(dialCtx),
ftp.DialWithExplicitTLS(tlsConfig),
)
} else {
conn, err = ftp.Dial(address, ftp.DialWithContext(dialCtx))
}
if err != nil {
return nil, fmt.Errorf("failed to dial FTP server: %w", err)
}
err = conn.Login(f.Username, password)
if err != nil {
_ = conn.Quit()
return nil, fmt.Errorf("failed to login to FTP server: %w", err)
}
return conn, nil
}
func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
path = strings.TrimPrefix(path, "/")
path = strings.TrimSuffix(path, "/")
if path == "" {
return nil
}
parts := strings.Split(path, "/")
currentPath := ""
for _, part := range parts {
if part == "" || part == "." {
continue
}
if currentPath == "" {
currentPath = part
} else {
currentPath = currentPath + "/" + part
}
err := conn.ChangeDir(currentPath)
if err != nil {
err = conn.MakeDir(currentPath)
if err != nil {
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
}
}
err = conn.ChangeDirToParent()
if err != nil {
return fmt.Errorf("failed to change to parent directory: %w", err)
}
}
return nil
}
func (f *FTPStorage) getFilePath(filename string) string {
if f.Path == "" {
return filename
}
path := strings.TrimPrefix(f.Path, "/")
path = strings.TrimSuffix(path, "/")
return path + "/" + filename
}
type ftpFileReader struct {
response *ftp.Response
conn *ftp.ServerConn
}
func (r *ftpFileReader) Read(p []byte) (n int, err error) {
return r.response.Read(p)
}
func (r *ftpFileReader) Close() error {
var errs []error
if r.response != nil {
if err := r.response.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close response: %w", err))
}
}
if r.conn != nil {
if err := r.conn.Quit(); err != nil {
errs = append(errs, fmt.Errorf("failed to close connection: %w", err))
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
}
type contextReader struct {
ctx context.Context
reader io.Reader
}
func (r *contextReader) Read(p []byte) (n int, err error) {
select {
case <-r.ctx.Done():
return 0, r.ctx.Err()
default:
return r.reader.Read(p)
}
}

View File

@@ -7,6 +7,9 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -15,9 +18,22 @@ import (
"golang.org/x/oauth2/google"
drive "google.golang.org/api/drive/v3"
"google.golang.org/api/googleapi"
"google.golang.org/api/option"
)
const (
gdConnectTimeout = 30 * time.Second
gdResponseTimeout = 30 * time.Second
gdIdleConnTimeout = 90 * time.Second
gdTLSHandshakeTimeout = 30 * time.Second
// Chunk size for Google Drive resumable uploads - 16MB provides good balance
// between memory usage and upload efficiency. Google Drive requires chunks
// to be multiples of 256KB for resumable uploads.
gdChunkSize = 16 * 1024 * 1024
)
type GoogleDriveStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
ClientID string `json:"clientId" gorm:"not null;type:text;column:client_id"`
@@ -30,30 +46,44 @@ func (s *GoogleDriveStorage) TableName() string {
}
func (s *GoogleDriveStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
ctx := context.Background()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
filename := fileID.String()
// Ensure the postgresus_backups folder exists
folderID, err := s.ensureBackupsFolderExists(ctx, driveService)
if err != nil {
return fmt.Errorf("failed to create/find backups folder: %w", err)
}
// Delete any previous copy so we keep at most one object per logical file.
_ = s.deleteByName(ctx, driveService, filename, folderID) // ignore "not found"
_ = s.deleteByName(ctx, driveService, filename, folderID)
fileMeta := &drive.File{
Name: filename,
Parents: []string{folderID},
}
_, err = driveService.Files.Create(fileMeta).Media(file).Context(ctx).Do()
backpressureReader := &backpressureReader{
reader: file,
ctx: ctx,
chunkSize: gdChunkSize,
buf: make([]byte, gdChunkSize),
}
_, err = driveService.Files.Create(fileMeta).
Media(backpressureReader, googleapi.ChunkSize(gdChunkSize)).
Context(ctx).
Do()
if err != nil {
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled: %w", ctx.Err())
default:
}
return fmt.Errorf("failed to upload file to Google Drive: %w", err)
}
@@ -68,34 +98,95 @@ func (s *GoogleDriveStorage) SaveFile(
})
}
func (s *GoogleDriveStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
type backpressureReader struct {
reader io.Reader
ctx context.Context
chunkSize int
buf []byte
bufStart int
bufEnd int
totalBytes int64
chunkCount int
}
func (r *backpressureReader) Read(p []byte) (n int, err error) {
select {
case <-r.ctx.Done():
return 0, r.ctx.Err()
default:
}
if r.bufStart >= r.bufEnd {
r.chunkCount++
bytesRead, readErr := io.ReadFull(r.reader, r.buf)
if bytesRead > 0 {
r.bufStart = 0
r.bufEnd = bytesRead
}
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
return 0, readErr
}
if bytesRead == 0 && readErr == io.EOF {
return 0, io.EOF
}
}
n = copy(p, r.buf[r.bufStart:r.bufEnd])
r.bufStart += n
r.totalBytes += int64(n)
if r.bufStart >= r.bufEnd {
select {
case <-r.ctx.Done():
return n, r.ctx.Err()
default:
}
}
return n, nil
}
func (s *GoogleDriveStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
var result io.ReadCloser
err := s.withRetryOnAuth(func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
return fmt.Errorf("failed to find backups folder: %w", err)
}
err := s.withRetryOnAuth(
context.Background(),
encryptor,
func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
return fmt.Errorf("failed to find backups folder: %w", err)
}
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
if err != nil {
return err
}
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
if err != nil {
return err
}
resp, err := driveService.Files.Get(fileIDGoogle).Download()
if err != nil {
return fmt.Errorf("failed to download file from Google Drive: %w", err)
}
resp, err := driveService.Files.Get(fileIDGoogle).Download()
if err != nil {
return fmt.Errorf("failed to download file from Google Drive: %w", err)
}
result = resp.Body
return nil
})
result = resp.Body
return nil
},
)
return result, err
}
func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
ctx := context.Background()
func (s *GoogleDriveStorage) DeleteFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) error {
ctx := context.Background()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
folderID, err := s.findBackupsFolder(driveService)
if err != nil {
return fmt.Errorf("failed to find backups folder: %w", err)
@@ -105,7 +196,7 @@ func (s *GoogleDriveStorage) DeleteFile(fileID uuid.UUID) error {
})
}
func (s *GoogleDriveStorage) Validate() error {
func (s *GoogleDriveStorage) Validate(encryptor encryption.FieldEncryptor) error {
switch {
case s.ClientID == "":
return errors.New("client ID is required")
@@ -115,7 +206,12 @@ func (s *GoogleDriveStorage) Validate() error {
return errors.New("token JSON is required")
}
// Also validate that the token JSON contains a refresh token
// Skip JSON validation if token is already encrypted
if strings.HasPrefix(s.TokenJSON, "enc:") {
return nil
}
// Validate that the token JSON contains a refresh token
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON format: %w", err)
@@ -128,9 +224,9 @@ func (s *GoogleDriveStorage) Validate() error {
return nil
}
func (s *GoogleDriveStorage) TestConnection() error {
return s.withRetryOnAuth(func(driveService *drive.Service) error {
ctx := context.Background()
func (s *GoogleDriveStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
ctx := context.Background()
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
testFilename := "test-connection-" + uuid.New().String()
testData := []byte("test")
@@ -196,6 +292,26 @@ func (s *GoogleDriveStorage) HideSensitiveData() {
s.TokenJSON = ""
}
func (s *GoogleDriveStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.ClientSecret != "" {
s.ClientSecret, err = encryptor.Encrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive client secret: %w", err)
}
}
if s.TokenJSON != "" {
s.TokenJSON, err = encryptor.Encrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to encrypt Google Drive token JSON: %w", err)
}
}
return nil
}
func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
s.ClientID = incoming.ClientID
@@ -209,18 +325,34 @@ func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
}
// withRetryOnAuth executes the provided function with retry logic for authentication errors
func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) error {
driveService, err := s.getDriveService()
func (s *GoogleDriveStorage) withRetryOnAuth(
ctx context.Context,
encryptor encryption.FieldEncryptor,
fn func(*drive.Service) error,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
driveService, err := s.getDriveService(encryptor)
if err != nil {
return err
}
err = fn(driveService)
if err != nil && s.isAuthError(err) {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// Try to refresh token and retry once
fmt.Printf("Google Drive auth error detected, attempting token refresh: %v\n", err)
if refreshErr := s.refreshToken(); refreshErr != nil {
if refreshErr := s.refreshToken(encryptor); refreshErr != nil {
// If refresh fails, return a more helpful error message
if strings.Contains(refreshErr.Error(), "invalid_grant") ||
strings.Contains(refreshErr.Error(), "refresh token") {
@@ -237,7 +369,7 @@ func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) erro
fmt.Printf("Token refresh successful, retrying operation\n")
// Get new service with refreshed token
driveService, err = s.getDriveService()
driveService, err = s.getDriveService(encryptor)
if err != nil {
return fmt.Errorf("failed to create service after token refresh: %w", err)
}
@@ -268,13 +400,24 @@ func (s *GoogleDriveStorage) isAuthError(err error) bool {
}
// refreshToken refreshes the OAuth2 token and updates the TokenJSON field
func (s *GoogleDriveStorage) refreshToken() error {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) refreshToken(encryptor encryption.FieldEncryptor) error {
if err := s.Validate(encryptor); err != nil {
return err
}
// Decrypt credentials before use
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return fmt.Errorf("invalid token JSON: %w", err)
}
@@ -289,12 +432,12 @@ func (s *GoogleDriveStorage) refreshToken() error {
token.Expiry)
// Debug: Print the full token JSON structure (sensitive data masked)
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(s.TokenJSON))
fmt.Printf("Original token JSON structure: %s\n", maskSensitiveData(tokenJSON))
ctx := context.Background()
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}
@@ -330,7 +473,7 @@ func (s *GoogleDriveStorage) refreshToken() error {
newToken.RefreshToken = token.RefreshToken
}
// Update the stored token JSON
// Update the stored token JSON (keep as plaintext in memory, encryption happens on save)
newTokenJSON, err := json.Marshal(newToken)
if err != nil {
return fmt.Errorf("failed to marshal refreshed token: %w", err)
@@ -368,13 +511,25 @@ func truncateString(s string, maxLen int) string {
return s[:maxLen]
}
func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
if err := s.Validate(); err != nil {
func (s *GoogleDriveStorage) getDriveService(
encryptor encryption.FieldEncryptor,
) (*drive.Service, error) {
if err := s.Validate(encryptor); err != nil {
return nil, err
}
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
}
tokenJSON, err := encryptor.Decrypt(s.StorageID, s.TokenJSON)
if err != nil {
return nil, fmt.Errorf("failed to decrypt Google Drive token JSON: %w", err)
}
var token oauth2.Token
if err := json.Unmarshal([]byte(s.TokenJSON), &token); err != nil {
if err := json.Unmarshal([]byte(tokenJSON), &token); err != nil {
return nil, fmt.Errorf("invalid token JSON: %w", err)
}
@@ -382,23 +537,23 @@ func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
cfg := &oauth2.Config{
ClientID: s.ClientID,
ClientSecret: s.ClientSecret,
ClientSecret: clientSecret,
Endpoint: google.Endpoint,
Scopes: []string{"https://www.googleapis.com/auth/drive.file"},
}
tokenSource := cfg.TokenSource(ctx, &token)
// Force token validation to ensure we're using the current token
currentToken, err := tokenSource.Token()
if err != nil {
return nil, fmt.Errorf("failed to get current token: %w", err)
}
// Create a new token source with the validated token
validatedTokenSource := oauth2.StaticTokenSource(currentToken)
driveService, err := drive.NewService(ctx, option.WithTokenSource(validatedTokenSource))
httpClient := s.buildHTTPClient(validatedTokenSource)
driveService, err := drive.NewService(ctx, option.WithHTTPClient(httpClient))
if err != nil {
return nil, fmt.Errorf("unable to create Drive client: %w", err)
}
@@ -406,6 +561,24 @@ func (s *GoogleDriveStorage) getDriveService() (*drive.Service, error) {
return driveService, nil
}
func (s *GoogleDriveStorage) buildHTTPClient(tokenSource oauth2.TokenSource) *http.Client {
transport := &http.Transport{
DialContext: (&net.Dialer{
Timeout: gdConnectTimeout,
}).DialContext,
TLSHandshakeTimeout: gdTLSHandshakeTimeout,
ResponseHeaderTimeout: gdResponseTimeout,
IdleConnTimeout: gdIdleConnTimeout,
}
return &http.Client{
Transport: &oauth2.Transport{
Source: tokenSource,
Base: transport,
},
}
}
func (s *GoogleDriveStorage) lookupFileID(
driveService *drive.Service,
name string,

View File

@@ -1,17 +1,26 @@
package local_storage
import (
"context"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"postgresus-backend/internal/config"
"postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"github.com/google/uuid"
)
const (
// Chunk size for local storage writes - 8MB per buffer with double-buffering
// allows overlapped I/O while keeping total memory under 32MB.
// Two 8MB buffers = 16MB for local storage, plus 8MB for pg_dump buffer = ~25MB total.
localChunkSize = 8 * 1024 * 1024
)
// LocalStorage uses ./postgresus_local_backups folder as a
// directory for backups and ./postgresus_local_temp folder as a
// directory for temp files
@@ -23,7 +32,19 @@ func (l *LocalStorage) TableName() string {
return "local_storages"
}
func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (l *LocalStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
err := files_utils.EnsureDirectories([]string{
@@ -54,7 +75,7 @@ func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.R
}()
logger.Debug("Copying file data to temp file", "fileId", fileID.String())
_, err = io.Copy(tempFile, file)
_, err = copyWithContext(ctx, tempFile, file)
if err != nil {
logger.Error("Failed to write to temp file", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to write to temp file: %w", err)
@@ -107,7 +128,10 @@ func (l *LocalStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.R
return nil
}
func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
func (l *LocalStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -122,7 +146,7 @@ func (l *LocalStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return file, nil
}
func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
func (l *LocalStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
filePath := filepath.Join(config.GetEnv().DataFolder, fileID.String())
if _, err := os.Stat(filePath); os.IsNotExist(err) {
@@ -136,11 +160,11 @@ func (l *LocalStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (l *LocalStorage) Validate() error {
func (l *LocalStorage) Validate(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
testFile := filepath.Join(config.GetEnv().TempFolder, "test_connection")
f, err := os.Create(testFile)
if err != nil {
@@ -160,5 +184,41 @@ func (l *LocalStorage) TestConnection() error {
func (l *LocalStorage) HideSensitiveData() {
}
func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return nil
}
func (l *LocalStorage) Update(incoming *LocalStorage) {
}
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
buf := make([]byte, localChunkSize)
var written int64
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, readErr := src.Read(buf)
if nr > 0 {
nw, writeErr := dst.Write(buf[:nr])
written += int64(nw)
if writeErr != nil {
return written, writeErr
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if readErr == io.EOF {
return written, nil
}
if readErr != nil {
return written, readErr
}
}
}

View File

@@ -1,6 +1,7 @@
package nas_storage
import (
"context"
"crypto/tls"
"errors"
"fmt"
@@ -8,6 +9,7 @@ import (
"log/slog"
"net"
"path/filepath"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -15,6 +17,13 @@ import (
"github.com/hirochachacha/go-smb2"
)
const (
// Chunk size for NAS uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for NAS to confirm receipt.
nasChunkSize = 16 * 1024 * 1024
)
type NASStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
Host string `json:"host" gorm:"not null;type:text;column:host"`
@@ -31,10 +40,22 @@ func (n *NASStorage) TableName() string {
return "nas_storages"
}
func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
func (n *NASStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
session, err := n.createSession()
session, err := n.createSessionWithContext(ctx, encryptor)
if err != nil {
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to create NAS session: %w", err)
@@ -115,7 +136,7 @@ func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Rea
}()
logger.Debug("Copying file data to NAS", "fileId", fileID.String())
_, err = io.Copy(nasFile, file)
_, err = copyWithContext(ctx, nasFile, file)
if err != nil {
logger.Error("Failed to write file to NAS", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to write file to NAS: %w", err)
@@ -131,8 +152,11 @@ func (n *NASStorage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Rea
return nil
}
func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
session, err := n.createSession()
func (n *NASStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
session, err := n.createSession(encryptor)
if err != nil {
return nil, fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -168,8 +192,8 @@ func (n *NASStorage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
}, nil
}
func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
session, err := n.createSession()
func (n *NASStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to create NAS session: %w", err)
}
@@ -202,7 +226,7 @@ func (n *NASStorage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (n *NASStorage) Validate() error {
func (n *NASStorage) Validate(encryptor encryption.FieldEncryptor) error {
if n.Host == "" {
return errors.New("NAS host is required")
}
@@ -219,12 +243,11 @@ func (n *NASStorage) Validate() error {
return errors.New("NAS port must be between 1 and 65535")
}
// Test the configuration by creating a session
return n.TestConnection()
return nil
}
func (n *NASStorage) TestConnection() error {
session, err := n.createSession()
func (n *NASStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
session, err := n.createSession(encryptor)
if err != nil {
return fmt.Errorf("failed to connect to NAS: %w", err)
}
@@ -255,6 +278,18 @@ func (n *NASStorage) HideSensitiveData() {
n.Password = ""
}
func (n *NASStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.Password != "" {
encrypted, err := encryptor.Encrypt(n.StorageID, n.Password)
if err != nil {
return fmt.Errorf("failed to encrypt NAS password: %w", err)
}
n.Password = encrypted
}
return nil
}
func (n *NASStorage) Update(incoming *NASStorage) {
n.Host = incoming.Host
n.Port = incoming.Port
@@ -269,23 +304,33 @@ func (n *NASStorage) Update(incoming *NASStorage) {
}
}
func (n *NASStorage) createSession() (*smb2.Session, error) {
// Create connection with timeout
conn, err := n.createConnection()
func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.Session, error) {
return n.createSessionWithContext(context.Background(), encryptor)
}
func (n *NASStorage) createSessionWithContext(
ctx context.Context,
encryptor encryption.FieldEncryptor,
) (*smb2.Session, error) {
conn, err := n.createConnectionWithContext(ctx)
if err != nil {
return nil, err
}
// Create SMB2 dialer
password, err := encryptor.Decrypt(n.StorageID, n.Password)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to decrypt NAS password: %w", err)
}
d := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: n.Username,
Password: n.Password,
Password: password,
Domain: n.Domain,
},
}
// Create session
session, err := d.Dial(conn)
if err != nil {
_ = conn.Close()
@@ -295,34 +340,30 @@ func (n *NASStorage) createSession() (*smb2.Session, error) {
return session, nil
}
func (n *NASStorage) createConnection() (net.Conn, error) {
func (n *NASStorage) createConnectionWithContext(ctx context.Context) (net.Conn, error) {
address := net.JoinHostPort(n.Host, fmt.Sprintf("%d", n.Port))
// Create connection with timeout
dialer := &net.Dialer{
Timeout: 10 * time.Second,
Timeout: 30 * time.Second,
}
if n.UseSSL {
// Use TLS connection
tlsConfig := &tls.Config{
ServerName: n.Host,
InsecureSkipVerify: false, // Change to true if you want to skip cert verification
InsecureSkipVerify: false,
}
conn, err := tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
if err != nil {
return nil, fmt.Errorf("failed to create SSL connection to %s: %w", address, err)
}
return conn, nil
} else {
// Use regular TCP connection
conn, err := dialer.Dial("tcp", address)
if err != nil {
return nil, fmt.Errorf("failed to create connection to %s: %w", address, err)
}
return conn, nil
}
conn, err := dialer.DialContext(ctx, "tcp", address)
if err != nil {
return nil, fmt.Errorf("failed to create connection to %s: %w", address, err)
}
return conn, nil
}
func (n *NASStorage) ensureDirectory(fs *smb2.Share, path string) error {
@@ -417,3 +458,71 @@ func (r *nasFileReader) Close() error {
return nil
}
type writeResult struct {
bytesWritten int
writeErr error
}
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
buf := make([]byte, nasChunkSize)
var written int64
for {
select {
case <-ctx.Done():
return written, ctx.Err()
default:
}
nr, readErr := io.ReadFull(src, buf)
if nr == 0 && readErr == io.EOF {
break
}
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
return written, readErr
}
writeResultCh := make(chan writeResult, 1)
go func() {
nw, writeErr := dst.Write(buf[0:nr])
writeResultCh <- writeResult{nw, writeErr}
}()
var nw int
var writeErr error
select {
case <-ctx.Done():
return written, ctx.Err()
case result := <-writeResultCh:
nw = result.bytesWritten
writeErr = result.writeErr
}
if nw < 0 || nr < nw {
nw = 0
if writeErr == nil {
writeErr = errors.New("invalid write result")
}
}
if writeErr != nil {
return written, writeErr
}
if nr != nw {
return written, io.ErrShortWrite
}
written += int64(nw)
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
break
}
}
return written, nil
}

View File

@@ -0,0 +1,293 @@
package rclone_storage
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"log/slog"
"postgresus-backend/internal/util/encryption"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/rclone/rclone/fs"
"github.com/rclone/rclone/fs/config"
"github.com/rclone/rclone/fs/operations"
_ "github.com/rclone/rclone/backend/all"
)
const (
rcloneOperationTimeout = 30 * time.Second
)
var rcloneConfigMu sync.Mutex
type RcloneStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
ConfigContent string `json:"configContent" gorm:"not null;type:text;column:config_content"`
RemotePath string `json:"remotePath" gorm:"type:text;column:remote_path"`
}
func (r *RcloneStorage) TableName() string {
return "rclone_storages"
}
func (r *RcloneStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.Info("Starting to save file to rclone storage", "fileId", fileID.String())
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
logger.Error("Failed to create rclone filesystem", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
logger.Debug("Uploading file via rclone", "fileId", fileID.String(), "filePath", filePath)
_, err = operations.Rcat(ctx, remoteFs, filePath, io.NopCloser(file), time.Now().UTC(), nil)
if err != nil {
select {
case <-ctx.Done():
logger.Info("Rclone upload cancelled", "fileId", fileID.String())
return ctx.Err()
default:
logger.Error(
"Failed to upload file via rclone",
"fileId",
fileID.String(),
"error",
err,
)
return fmt.Errorf("failed to upload file via rclone: %w", err)
}
}
logger.Info(
"Successfully saved file to rclone storage",
"fileId",
fileID.String(),
"filePath",
filePath,
)
return nil
}
func (r *RcloneStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
ctx := context.Background()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
return nil, fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {
return nil, fmt.Errorf("failed to get object from rclone: %w", err)
}
reader, err := obj.Open(ctx)
if err != nil {
return nil, fmt.Errorf("failed to open object from rclone: %w", err)
}
return reader, nil
}
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
ctx := context.Background()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
filePath := r.getFilePath(fileID.String())
obj, err := remoteFs.NewObject(ctx, filePath)
if err != nil {
return nil
}
err = obj.Remove(ctx)
if err != nil {
return fmt.Errorf("failed to delete file from rclone: %w", err)
}
return nil
}
func (r *RcloneStorage) Validate(encryptor encryption.FieldEncryptor) error {
if r.ConfigContent == "" {
return errors.New("rclone config content is required")
}
return nil
}
func (r *RcloneStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
ctx, cancel := context.WithTimeout(context.Background(), rcloneOperationTimeout)
defer cancel()
remoteFs, err := r.getFs(ctx, encryptor)
if err != nil {
return fmt.Errorf("failed to create rclone filesystem: %w", err)
}
testFileID := uuid.New().String() + "-test"
testFilePath := r.getFilePath(testFileID)
testData := strings.NewReader("test connection")
_, err = operations.Rcat(
ctx,
remoteFs,
testFilePath,
io.NopCloser(testData),
time.Now().UTC(),
nil,
)
if err != nil {
return fmt.Errorf("failed to upload test file via rclone: %w", err)
}
obj, err := remoteFs.NewObject(ctx, testFilePath)
if err != nil {
return fmt.Errorf("failed to get test file from rclone: %w", err)
}
err = obj.Remove(ctx)
if err != nil {
return fmt.Errorf("failed to delete test file from rclone: %w", err)
}
return nil
}
func (r *RcloneStorage) HideSensitiveData() {
r.ConfigContent = ""
}
func (r *RcloneStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if r.ConfigContent != "" {
encrypted, err := encryptor.Encrypt(r.StorageID, r.ConfigContent)
if err != nil {
return fmt.Errorf("failed to encrypt rclone config content: %w", err)
}
r.ConfigContent = encrypted
}
return nil
}
func (r *RcloneStorage) Update(incoming *RcloneStorage) {
r.RemotePath = incoming.RemotePath
if incoming.ConfigContent != "" {
r.ConfigContent = incoming.ConfigContent
}
}
func (r *RcloneStorage) getFs(
ctx context.Context,
encryptor encryption.FieldEncryptor,
) (fs.Fs, error) {
configContent, err := encryptor.Decrypt(r.StorageID, r.ConfigContent)
if err != nil {
return nil, fmt.Errorf("failed to decrypt rclone config content: %w", err)
}
rcloneConfigMu.Lock()
defer rcloneConfigMu.Unlock()
parsedConfig, err := parseConfigContent(configContent)
if err != nil {
return nil, fmt.Errorf("failed to parse rclone config: %w", err)
}
if len(parsedConfig) == 0 {
return nil, errors.New("rclone config must contain at least one remote section")
}
var remoteName string
for section, values := range parsedConfig {
remoteName = section
for key, value := range values {
config.FileSetValue(section, key, value)
}
}
remotePath := remoteName + ":"
if r.RemotePath != "" {
remotePath = remoteName + ":" + strings.TrimPrefix(r.RemotePath, "/")
}
remoteFs, err := fs.NewFs(ctx, remotePath)
if err != nil {
return nil, fmt.Errorf(
"failed to create rclone filesystem for remote '%s': %w",
remoteName,
err,
)
}
return remoteFs, nil
}
func (r *RcloneStorage) getFilePath(filename string) string {
return filename
}
func parseConfigContent(content string) (map[string]map[string]string, error) {
sections := make(map[string]map[string]string)
var currentSection string
scanner := bufio.NewScanner(strings.NewReader(content))
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
continue
}
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
currentSection = strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[")
if sections[currentSection] == nil {
sections[currentSection] = make(map[string]string)
}
continue
}
if currentSection != "" && strings.Contains(line, "=") {
parts := strings.SplitN(line, "=", 2)
key := strings.TrimSpace(parts[0])
value := ""
if len(parts) > 1 {
value = strings.TrimSpace(parts[1])
}
sections[currentSection][key] = value
}
}
return sections, scanner.Err()
}

View File

@@ -3,10 +3,14 @@ package s3_storage
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
@@ -15,6 +19,18 @@ import (
"github.com/minio/minio-go/v7/pkg/credentials"
)
const (
s3ConnectTimeout = 30 * time.Second
s3ResponseTimeout = 30 * time.Second
s3IdleConnTimeout = 90 * time.Second
s3TLSHandshakeTimeout = 30 * time.Second
// Chunk size for multipart uploads - 16MB provides good balance between
// memory usage and upload efficiency. This creates backpressure to pg_dump
// by only reading one chunk at a time and waiting for S3 to confirm receipt.
multipartChunkSize = 16 * 1024 * 1024
)
type S3Storage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
S3Bucket string `json:"s3Bucket" gorm:"not null;type:text;column:s3_bucket"`
@@ -22,44 +38,154 @@ type S3Storage struct {
S3AccessKey string `json:"s3AccessKey" gorm:"not null;type:text;column:s3_access_key"`
S3SecretKey string `json:"s3SecretKey" gorm:"not null;type:text;column:s3_secret_key"`
S3Endpoint string `json:"s3Endpoint" gorm:"type:text;column:s3_endpoint"`
S3Prefix string `json:"s3Prefix" gorm:"type:text;column:s3_prefix"`
S3UseVirtualHostedStyle bool `json:"s3UseVirtualHostedStyle" gorm:"default:false;column:s3_use_virtual_hosted_style"`
SkipTLSVerify bool `json:"skipTLSVerify" gorm:"default:false;column:skip_tls_verify"`
}
func (s *S3Storage) TableName() string {
return "s3_storages"
}
func (s *S3Storage) SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error {
client, err := s.getClient()
func (s *S3Storage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled before start: %w", ctx.Err())
default:
}
coreClient, err := s.getCoreClient(encryptor)
if err != nil {
return err
}
// Upload the file using MinIO client with streaming (size = -1 for unknown size)
_, err = client.PutObject(
context.TODO(),
objectKey := s.buildObjectKey(fileID.String())
uploadID, err := coreClient.NewMultipartUpload(
ctx,
s.S3Bucket,
fileID.String(),
file,
-1,
objectKey,
minio.PutObjectOptions{},
)
if err != nil {
return fmt.Errorf("failed to upload file to S3: %w", err)
return fmt.Errorf("failed to initiate multipart upload: %w", err)
}
var parts []minio.CompletePart
partNumber := 1
buf := make([]byte, multipartChunkSize)
for {
select {
case <-ctx.Done():
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
return fmt.Errorf("upload cancelled: %w", ctx.Err())
default:
}
n, readErr := io.ReadFull(file, buf)
if n == 0 && readErr == io.EOF {
break
}
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
return fmt.Errorf("read error: %w", readErr)
}
part, err := coreClient.PutObjectPart(
ctx,
s.S3Bucket,
objectKey,
uploadID,
partNumber,
bytes.NewReader(buf[:n]),
int64(n),
minio.PutObjectPartOptions{},
)
if err != nil {
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
select {
case <-ctx.Done():
return fmt.Errorf("upload cancelled: %w", ctx.Err())
default:
return fmt.Errorf("failed to upload part %d: %w", partNumber, err)
}
}
parts = append(parts, minio.CompletePart{
PartNumber: partNumber,
ETag: part.ETag,
})
partNumber++
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
break
}
}
if len(parts) == 0 {
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
client, err := s.getClient(encryptor)
if err != nil {
return err
}
_, err = client.PutObject(
ctx,
s.S3Bucket,
objectKey,
bytes.NewReader([]byte{}),
0,
minio.PutObjectOptions{},
)
if err != nil {
return fmt.Errorf("failed to upload empty file: %w", err)
}
return nil
}
_, err = coreClient.CompleteMultipartUpload(
ctx,
s.S3Bucket,
objectKey,
uploadID,
parts,
minio.PutObjectOptions{},
)
if err != nil {
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
return fmt.Errorf("failed to complete multipart upload: %w", err)
}
return nil
}
func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
client, err := s.getClient()
func (s *S3Storage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, err := s.getClient(encryptor)
if err != nil {
return nil, err
}
objectKey := s.buildObjectKey(fileID.String())
object, err := client.GetObject(
context.TODO(),
s.S3Bucket,
fileID.String(),
objectKey,
minio.GetObjectOptions{},
)
if err != nil {
@@ -84,17 +210,19 @@ func (s *S3Storage) GetFile(fileID uuid.UUID) (io.ReadCloser, error) {
return object, nil
}
func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
client, err := s.getClient()
func (s *S3Storage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
objectKey := s.buildObjectKey(fileID.String())
// Delete the object using MinIO client
err = client.RemoveObject(
context.TODO(),
s.S3Bucket,
fileID.String(),
objectKey,
minio.RemoveObjectOptions{},
)
if err != nil {
@@ -104,7 +232,7 @@ func (s *S3Storage) DeleteFile(fileID uuid.UUID) error {
return nil
}
func (s *S3Storage) Validate() error {
func (s *S3Storage) Validate(encryptor encryption.FieldEncryptor) error {
if s.S3Bucket == "" {
return errors.New("S3 bucket is required")
}
@@ -115,17 +243,11 @@ func (s *S3Storage) Validate() error {
return errors.New("S3 secret key is required")
}
// Try to create a client to validate the configuration
_, err := s.getClient()
if err != nil {
return fmt.Errorf("invalid S3 configuration: %w", err)
}
return nil
}
func (s *S3Storage) TestConnection() error {
client, err := s.getClient()
func (s *S3Storage) TestConnection(encryptor encryption.FieldEncryptor) error {
client, err := s.getClient(encryptor)
if err != nil {
return err
}
@@ -150,6 +272,7 @@ func (s *S3Storage) TestConnection() error {
// Test write and delete permissions by uploading and removing a small test file
testFileID := uuid.New().String() + "-test"
testObjectKey := s.buildObjectKey(testFileID)
testData := []byte("test connection")
testReader := bytes.NewReader(testData)
@@ -157,7 +280,7 @@ func (s *S3Storage) TestConnection() error {
_, err = client.PutObject(
ctx,
s.S3Bucket,
testFileID,
testObjectKey,
testReader,
int64(len(testData)),
minio.PutObjectOptions{},
@@ -170,7 +293,7 @@ func (s *S3Storage) TestConnection() error {
err = client.RemoveObject(
ctx,
s.S3Bucket,
testFileID,
testObjectKey,
minio.RemoveObjectOptions{},
)
if err != nil {
@@ -185,10 +308,32 @@ func (s *S3Storage) HideSensitiveData() {
s.S3SecretKey = ""
}
func (s *S3Storage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
var err error
if s.S3AccessKey != "" {
s.S3AccessKey, err = encryptor.Encrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 access key: %w", err)
}
}
if s.S3SecretKey != "" {
s.S3SecretKey, err = encryptor.Encrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return fmt.Errorf("failed to encrypt S3 secret key: %w", err)
}
}
return nil
}
func (s *S3Storage) Update(incoming *S3Storage) {
s.S3Bucket = incoming.S3Bucket
s.S3Region = incoming.S3Region
s.S3Endpoint = incoming.S3Endpoint
s.S3UseVirtualHostedStyle = incoming.S3UseVirtualHostedStyle
s.SkipTLSVerify = incoming.SkipTLSVerify
if incoming.S3AccessKey != "" {
s.S3AccessKey = incoming.S3AccessKey
@@ -197,11 +342,75 @@ func (s *S3Storage) Update(incoming *S3Storage) {
if incoming.S3SecretKey != "" {
s.S3SecretKey = incoming.S3SecretKey
}
// we do not allow to change the prefix after creation,
// otherwise we will have to migrate all the data to the new prefix
}
func (s *S3Storage) getClient() (*minio.Client, error) {
endpoint := s.S3Endpoint
useSSL := true
func (s *S3Storage) buildObjectKey(fileName string) string {
if s.S3Prefix == "" {
return fileName
}
prefix := s.S3Prefix
prefix = strings.TrimPrefix(prefix, "/")
if !strings.HasSuffix(prefix, "/") {
prefix = prefix + "/"
}
return prefix + fileName
}
func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Client, error) {
endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, err := s.getClientParams(
encryptor,
)
if err != nil {
return nil, err
}
minioClient, err := minio.New(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
Secure: useSSL,
Region: s.S3Region,
BucketLookup: bucketLookup,
Transport: transport,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize MinIO client: %w", err)
}
return minioClient, nil
}
func (s *S3Storage) getCoreClient(encryptor encryption.FieldEncryptor) (*minio.Core, error) {
endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, err := s.getClientParams(
encryptor,
)
if err != nil {
return nil, err
}
coreClient, err := minio.NewCore(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
Secure: useSSL,
Region: s.S3Region,
BucketLookup: bucketLookup,
Transport: transport,
})
if err != nil {
return nil, fmt.Errorf("failed to initialize MinIO Core client: %w", err)
}
return coreClient, nil
}
func (s *S3Storage) getClientParams(
encryptor encryption.FieldEncryptor,
) (endpoint string, useSSL bool, accessKey string, secretKey string, bucketLookup minio.BucketLookupType, transport *http.Transport, err error) {
endpoint = s.S3Endpoint
useSSL = true
if strings.HasPrefix(endpoint, "http://") {
useSSL = false
@@ -210,20 +419,36 @@ func (s *S3Storage) getClient() (*minio.Client, error) {
endpoint = strings.TrimPrefix(endpoint, "https://")
}
// If no endpoint is provided, use the AWS S3 endpoint for the region
if endpoint == "" {
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", s.S3Region)
}
// Initialize the MinIO client
minioClient, err := minio.New(endpoint, &minio.Options{
Creds: credentials.NewStaticV4(s.S3AccessKey, s.S3SecretKey, ""),
Secure: useSSL,
Region: s.S3Region,
})
accessKey, err = encryptor.Decrypt(s.StorageID, s.S3AccessKey)
if err != nil {
return nil, fmt.Errorf("failed to initialize MinIO client: %w", err)
return "", false, "", "", 0, nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
}
return minioClient, nil
secretKey, err = encryptor.Decrypt(s.StorageID, s.S3SecretKey)
if err != nil {
return "", false, "", "", 0, nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
}
bucketLookup = minio.BucketLookupAuto
if s.S3UseVirtualHostedStyle {
bucketLookup = minio.BucketLookupDNS
}
transport = &http.Transport{
DialContext: (&net.Dialer{
Timeout: s3ConnectTimeout,
}).DialContext,
TLSHandshakeTimeout: s3TLSHandshakeTimeout,
ResponseHeaderTimeout: s3ResponseTimeout,
IdleConnTimeout: s3IdleConnTimeout,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: s.SkipTLSVerify,
},
}
return endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, nil
}

View File

@@ -0,0 +1,430 @@
package sftp_storage
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"net"
"postgresus-backend/internal/util/encryption"
"strings"
"time"
"github.com/google/uuid"
"github.com/pkg/sftp"
"golang.org/x/crypto/ssh"
)
const (
sftpConnectTimeout = 30 * time.Second
sftpTestConnectTimeout = 10 * time.Second
)
type SFTPStorage struct {
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
Host string `json:"host" gorm:"not null;type:text;column:host"`
Port int `json:"port" gorm:"not null;default:22;column:port"`
Username string `json:"username" gorm:"not null;type:text;column:username"`
Password string `json:"password" gorm:"type:text;column:password"`
PrivateKey string `json:"privateKey" gorm:"type:text;column:private_key"`
Path string `json:"path" gorm:"type:text;column:path"`
SkipHostKeyVerify bool `json:"skipHostKeyVerify" gorm:"not null;default:false;column:skip_host_key_verify"`
}
func (s *SFTPStorage) TableName() string {
return "sftp_storages"
}
func (s *SFTPStorage) SaveFile(
ctx context.Context,
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
fileID uuid.UUID,
file io.Reader,
) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.Info("Starting to save file to SFTP storage", "fileId", fileID.String(), "host", s.Host)
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
logger.Error("Failed to connect to SFTP", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to connect to SFTP: %w", err)
}
defer func() {
if closeErr := client.Close(); closeErr != nil {
logger.Error(
"Failed to close SFTP client",
"fileId",
fileID.String(),
"error",
closeErr,
)
}
if closeErr := sshConn.Close(); closeErr != nil {
logger.Error(
"Failed to close SSH connection",
"fileId",
fileID.String(),
"error",
closeErr,
)
}
}()
if s.Path != "" {
if err := s.ensureDirectory(client, s.Path); err != nil {
logger.Error(
"Failed to ensure directory",
"fileId",
fileID.String(),
"path",
s.Path,
"error",
err,
)
return fmt.Errorf("failed to ensure directory: %w", err)
}
}
filePath := s.getFilePath(fileID.String())
logger.Debug("Uploading file to SFTP", "fileId", fileID.String(), "filePath", filePath)
remoteFile, err := client.Create(filePath)
if err != nil {
logger.Error("Failed to create remote file", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to create remote file: %w", err)
}
defer func() {
_ = remoteFile.Close()
}()
ctxReader := &contextReader{ctx: ctx, reader: file}
_, err = io.Copy(remoteFile, ctxReader)
if err != nil {
select {
case <-ctx.Done():
logger.Info("SFTP upload cancelled", "fileId", fileID.String())
return ctx.Err()
default:
logger.Error("Failed to upload file to SFTP", "fileId", fileID.String(), "error", err)
return fmt.Errorf("failed to upload file to SFTP: %w", err)
}
}
logger.Info(
"Successfully saved file to SFTP storage",
"fileId",
fileID.String(),
"filePath",
filePath,
)
return nil
}
func (s *SFTPStorage) GetFile(
encryptor encryption.FieldEncryptor,
fileID uuid.UUID,
) (io.ReadCloser, error) {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
return nil, fmt.Errorf("failed to connect to SFTP: %w", err)
}
filePath := s.getFilePath(fileID.String())
remoteFile, err := client.Open(filePath)
if err != nil {
_ = client.Close()
_ = sshConn.Close()
return nil, fmt.Errorf("failed to open file from SFTP: %w", err)
}
return &sftpFileReader{
file: remoteFile,
client: client,
sshConn: sshConn,
}, nil
}
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
if err != nil {
return fmt.Errorf("failed to connect to SFTP: %w", err)
}
defer func() {
_ = client.Close()
_ = sshConn.Close()
}()
filePath := s.getFilePath(fileID.String())
_, err = client.Stat(filePath)
if err != nil {
return nil
}
err = client.Remove(filePath)
if err != nil {
return fmt.Errorf("failed to delete file from SFTP: %w", err)
}
return nil
}
func (s *SFTPStorage) Validate(encryptor encryption.FieldEncryptor) error {
if s.Host == "" {
return errors.New("SFTP host is required")
}
if s.Username == "" {
return errors.New("SFTP username is required")
}
if s.Password == "" && s.PrivateKey == "" {
return errors.New("SFTP password or private key is required")
}
if s.Port <= 0 || s.Port > 65535 {
return errors.New("SFTP port must be between 1 and 65535")
}
return nil
}
func (s *SFTPStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
ctx, cancel := context.WithTimeout(context.Background(), sftpTestConnectTimeout)
defer cancel()
client, sshConn, err := s.connectWithContext(ctx, encryptor, sftpTestConnectTimeout)
if err != nil {
return fmt.Errorf("failed to connect to SFTP: %w", err)
}
defer func() {
_ = client.Close()
_ = sshConn.Close()
}()
if s.Path != "" {
if err := s.ensureDirectory(client, s.Path); err != nil {
return fmt.Errorf("failed to access or create path '%s': %w", s.Path, err)
}
}
return nil
}
func (s *SFTPStorage) HideSensitiveData() {
s.Password = ""
s.PrivateKey = ""
}
func (s *SFTPStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if s.Password != "" {
encrypted, err := encryptor.Encrypt(s.StorageID, s.Password)
if err != nil {
return fmt.Errorf("failed to encrypt SFTP password: %w", err)
}
s.Password = encrypted
}
if s.PrivateKey != "" {
encrypted, err := encryptor.Encrypt(s.StorageID, s.PrivateKey)
if err != nil {
return fmt.Errorf("failed to encrypt SFTP private key: %w", err)
}
s.PrivateKey = encrypted
}
return nil
}
func (s *SFTPStorage) Update(incoming *SFTPStorage) {
s.Host = incoming.Host
s.Port = incoming.Port
s.Username = incoming.Username
s.SkipHostKeyVerify = incoming.SkipHostKeyVerify
s.Path = incoming.Path
if incoming.Password != "" {
s.Password = incoming.Password
}
if incoming.PrivateKey != "" {
s.PrivateKey = incoming.PrivateKey
}
}
func (s *SFTPStorage) connect(
encryptor encryption.FieldEncryptor,
timeout time.Duration,
) (*sftp.Client, *ssh.Client, error) {
return s.connectWithContext(context.Background(), encryptor, timeout)
}
func (s *SFTPStorage) connectWithContext(
ctx context.Context,
encryptor encryption.FieldEncryptor,
timeout time.Duration,
) (*sftp.Client, *ssh.Client, error) {
var authMethods []ssh.AuthMethod
if s.Password != "" {
password, err := encryptor.Decrypt(s.StorageID, s.Password)
if err != nil {
return nil, nil, fmt.Errorf("failed to decrypt SFTP password: %w", err)
}
authMethods = append(authMethods, ssh.Password(password))
}
if s.PrivateKey != "" {
privateKey, err := encryptor.Decrypt(s.StorageID, s.PrivateKey)
if err != nil {
return nil, nil, fmt.Errorf("failed to decrypt SFTP private key: %w", err)
}
signer, err := ssh.ParsePrivateKey([]byte(privateKey))
if err != nil {
return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
}
authMethods = append(authMethods, ssh.PublicKeys(signer))
}
var hostKeyCallback ssh.HostKeyCallback
if s.SkipHostKeyVerify {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
} else {
hostKeyCallback = ssh.InsecureIgnoreHostKey()
}
config := &ssh.ClientConfig{
User: s.Username,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
Timeout: timeout,
}
address := fmt.Sprintf("%s:%d", s.Host, s.Port)
dialer := net.Dialer{Timeout: timeout}
conn, err := dialer.DialContext(ctx, "tcp", address)
if err != nil {
return nil, nil, fmt.Errorf("failed to dial SFTP server: %w", err)
}
sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SSH connection: %w", err)
}
sshClient := ssh.NewClient(sshConn, chans, reqs)
sftpClient, err := sftp.NewClient(sshClient)
if err != nil {
_ = sshClient.Close()
return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
}
return sftpClient, sshClient, nil
}
func (s *SFTPStorage) ensureDirectory(client *sftp.Client, path string) error {
path = strings.TrimPrefix(path, "/")
path = strings.TrimSuffix(path, "/")
if path == "" {
return nil
}
parts := strings.Split(path, "/")
currentPath := ""
for _, part := range parts {
if part == "" || part == "." {
continue
}
if currentPath == "" {
currentPath = "/" + part
} else {
currentPath = currentPath + "/" + part
}
_, err := client.Stat(currentPath)
if err != nil {
err = client.Mkdir(currentPath)
if err != nil {
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
}
}
}
return nil
}
func (s *SFTPStorage) getFilePath(filename string) string {
if s.Path == "" {
return filename
}
path := strings.TrimPrefix(s.Path, "/")
path = strings.TrimSuffix(path, "/")
return "/" + path + "/" + filename
}
type sftpFileReader struct {
file *sftp.File
client *sftp.Client
sshConn *ssh.Client
}
func (r *sftpFileReader) Read(p []byte) (n int, err error) {
return r.file.Read(p)
}
func (r *sftpFileReader) Close() error {
var errs []error
if r.file != nil {
if err := r.file.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close file: %w", err))
}
}
if r.client != nil {
if err := r.client.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close SFTP client: %w", err))
}
}
if r.sshConn != nil {
if err := r.sshConn.Close(); err != nil {
errs = append(errs, fmt.Errorf("failed to close SSH connection: %w", err))
}
}
if len(errs) > 0 {
return errs[0]
}
return nil
}
type contextReader struct {
ctx context.Context
reader io.Reader
}
func (r *contextReader) Read(p []byte) (n int, err error) {
select {
case <-r.ctx.Done():
return 0, r.ctx.Err()
default:
return r.reader.Read(p)
}
}

View File

@@ -30,17 +30,33 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
if storage.NASStorage != nil {
storage.NASStorage.StorageID = storage.ID
}
case StorageTypeAzureBlob:
if storage.AzureBlobStorage != nil {
storage.AzureBlobStorage.StorageID = storage.ID
}
case StorageTypeFTP:
if storage.FTPStorage != nil {
storage.FTPStorage.StorageID = storage.ID
}
case StorageTypeSFTP:
if storage.SFTPStorage != nil {
storage.SFTPStorage.StorageID = storage.ID
}
case StorageTypeRclone:
if storage.RcloneStorage != nil {
storage.RcloneStorage.StorageID = storage.ID
}
}
if storage.ID == uuid.Nil {
if err := tx.Create(storage).
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage").
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage", "FTPStorage", "SFTPStorage", "RcloneStorage").
Error; err != nil {
return err
}
} else {
if err := tx.Save(storage).
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage").
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage", "FTPStorage", "SFTPStorage", "RcloneStorage").
Error; err != nil {
return err
}
@@ -75,6 +91,34 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
return err
}
}
case StorageTypeAzureBlob:
if storage.AzureBlobStorage != nil {
storage.AzureBlobStorage.StorageID = storage.ID // Ensure ID is set
if err := tx.Save(storage.AzureBlobStorage).Error; err != nil {
return err
}
}
case StorageTypeFTP:
if storage.FTPStorage != nil {
storage.FTPStorage.StorageID = storage.ID // Ensure ID is set
if err := tx.Save(storage.FTPStorage).Error; err != nil {
return err
}
}
case StorageTypeSFTP:
if storage.SFTPStorage != nil {
storage.SFTPStorage.StorageID = storage.ID // Ensure ID is set
if err := tx.Save(storage.SFTPStorage).Error; err != nil {
return err
}
}
case StorageTypeRclone:
if storage.RcloneStorage != nil {
storage.RcloneStorage.StorageID = storage.ID // Ensure ID is set
if err := tx.Save(storage.RcloneStorage).Error; err != nil {
return err
}
}
}
return nil
@@ -96,6 +140,10 @@ func (r *StorageRepository) FindByID(id uuid.UUID) (*Storage, error) {
Preload("S3Storage").
Preload("GoogleDriveStorage").
Preload("NASStorage").
Preload("AzureBlobStorage").
Preload("FTPStorage").
Preload("SFTPStorage").
Preload("RcloneStorage").
Where("id = ?", id).
First(&s).Error; err != nil {
return nil, err
@@ -113,6 +161,10 @@ func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage
Preload("S3Storage").
Preload("GoogleDriveStorage").
Preload("NASStorage").
Preload("AzureBlobStorage").
Preload("FTPStorage").
Preload("SFTPStorage").
Preload("RcloneStorage").
Where("workspace_id = ?", workspaceID).
Order("name ASC").
Find(&storages).Error; err != nil {
@@ -150,6 +202,30 @@ func (r *StorageRepository) Delete(s *Storage) error {
return err
}
}
case StorageTypeAzureBlob:
if s.AzureBlobStorage != nil {
if err := tx.Delete(s.AzureBlobStorage).Error; err != nil {
return err
}
}
case StorageTypeFTP:
if s.FTPStorage != nil {
if err := tx.Delete(s.FTPStorage).Error; err != nil {
return err
}
}
case StorageTypeSFTP:
if s.SFTPStorage != nil {
if err := tx.Delete(s.SFTPStorage).Error; err != nil {
return err
}
}
case StorageTypeRclone:
if s.RcloneStorage != nil {
if err := tx.Delete(s.RcloneStorage).Error; err != nil {
return err
}
}
}
// Delete the main storage

View File

@@ -7,6 +7,7 @@ import (
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -15,6 +16,7 @@ type StorageService struct {
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *StorageService) SaveStorage(
@@ -44,7 +46,11 @@ func (s *StorageService) SaveStorage(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -61,7 +67,11 @@ func (s *StorageService) SaveStorage(
} else {
storage.WorkspaceID = workspaceID
if err := storage.Validate(); err != nil {
if err := storage.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := storage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -174,7 +184,7 @@ func (s *StorageService) TestStorageConnection(
return errors.New("insufficient permissions to test storage in this workspace")
}
err = storage.TestConnection()
err = storage.TestConnection(s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
storage.LastSaveError = &lastSaveError
@@ -207,7 +217,7 @@ func (s *StorageService) TestStorageConnectionDirect(
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
if err := existingStorage.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -216,7 +226,7 @@ func (s *StorageService) TestStorageConnectionDirect(
usingStorage = storage
}
return usingStorage.TestConnection()
return usingStorage.TestConnection(s.fieldEncryptor)
}
func (s *StorageService) GetStorageByID(

View File

@@ -1,7 +1,7 @@
package users_models
type SecretKey struct {
Secret string `gorm:"column:secret"`
Secret string `gorm:"column:secret" json:"-"`
}
func (SecretKey) TableName() string {

View File

@@ -0,0 +1,12 @@
package users_repositories
var userRepository = &UserRepository{}
var usersSettingsRepository = &UsersSettingsRepository{}
func GetUserRepository() *UserRepository {
return userRepository
}
func GetUsersSettingsRepository() *UsersSettingsRepository {
return usersSettingsRepository
}

View File

@@ -1,36 +0,0 @@
package users_repositories
import (
"errors"
user_models "postgresus-backend/internal/features/users/models"
"postgresus-backend/internal/storage"
"github.com/google/uuid"
"gorm.io/gorm"
)
type SecretKeyRepository struct{}
func (r *SecretKeyRepository) GetSecretKey() (string, error) {
var secretKey user_models.SecretKey
if err := storage.
GetDb().
First(&secretKey).Error; err != nil {
// create a new secret key if not found
if errors.Is(err, gorm.ErrRecordNotFound) {
newSecretKey := user_models.SecretKey{
Secret: uuid.New().String() + uuid.New().String(),
}
if err := storage.GetDb().Create(&newSecretKey).Error; err != nil {
return "", errors.New("failed to create new secret key")
}
return newSecretKey.Secret, nil
}
return "", err
}
return secretKey.Secret, nil
}

View File

@@ -1,25 +1,22 @@
package users_services
import (
user_repositories "postgresus-backend/internal/features/users/repositories"
"postgresus-backend/internal/features/encryption/secrets"
users_repositories "postgresus-backend/internal/features/users/repositories"
)
var secretKeyRepository = &user_repositories.SecretKeyRepository{}
var userRepository = &user_repositories.UserRepository{}
var usersSettingsRepository = &user_repositories.UsersSettingsRepository{}
var userService = &UserService{
userRepository,
secretKeyRepository,
users_repositories.GetUserRepository(),
secrets.GetSecretKeyService(),
settingsService,
nil,
}
var settingsService = &SettingsService{
usersSettingsRepository,
users_repositories.GetUsersSettingsRepository(),
nil,
}
var managementService = &UserManagementService{
userRepository,
users_repositories.GetUserRepository(),
nil,
}

View File

@@ -17,6 +17,7 @@ import (
"golang.org/x/oauth2/google"
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/encryption/secrets"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_interfaces "postgresus-backend/internal/features/users/interfaces"
@@ -25,10 +26,10 @@ import (
)
type UserService struct {
userRepository *users_repositories.UserRepository
secretKeyRepository *users_repositories.SecretKeyRepository
settingsService *SettingsService
auditLogWriter users_interfaces.AuditLogWriter
userRepository *users_repositories.UserRepository
secretKeyService *secrets.SecretKeyService
settingsService *SettingsService
auditLogWriter users_interfaces.AuditLogWriter
}
func (s *UserService) SetAuditLogWriter(writer users_interfaces.AuditLogWriter) {
@@ -162,7 +163,7 @@ func (s *UserService) SignIn(
}
func (s *UserService) GetUserFromToken(token string) (*users_models.User, error) {
secretKey, err := s.secretKeyRepository.GetSecretKey()
secretKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
return nil, fmt.Errorf("failed to get secret key: %w", err)
}
@@ -221,7 +222,7 @@ func (s *UserService) GetUserFromToken(token string) (*users_models.User, error)
func (s *UserService) GenerateAccessToken(
user *users_models.User,
) (*users_dto.SignInResponseDTO, error) {
secretKey, err := s.secretKeyRepository.GetSecretKey()
secretKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
return nil, fmt.Errorf("failed to get secret key: %w", err)
}
@@ -309,15 +310,6 @@ func (s *UserService) ChangeUserPasswordByEmail(email string, newPassword string
}
func (s *UserService) ChangeUserPassword(userID uuid.UUID, newPassword string) error {
user, err := s.userRepository.GetUserByID(userID)
if err != nil {
return fmt.Errorf("failed to get user: %w", err)
}
if !user.HasPassword() {
return errors.New("user has no password set")
}
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)

View File

@@ -309,7 +309,7 @@ func (s *WorkspaceService) GetAllWorkspaces() ([]*workspaces_models.Workspace, e
return s.workspaceRepository.GetAllWorkspaces()
}
func (s *WorkspaceService) GetWorkspaceByIDInternal(
func (s *WorkspaceService) GetWorkspaceByID(
workspaceID uuid.UUID,
) (*workspaces_models.Workspace, error) {
return s.workspaceRepository.GetWorkspaceByID(workspaceID)

View File

@@ -0,0 +1,11 @@
package encryption
import "postgresus-backend/internal/features/encryption/secrets"
var fieldEncryptor = &SecretKeyFieldEncryptor{
secrets.GetSecretKeyService(),
}
func GetFieldEncryptor() FieldEncryptor {
return fieldEncryptor
}

View File

@@ -0,0 +1,15 @@
package encryption
import "github.com/google/uuid"
type FieldEncryptor interface {
// Encrypt encrypts a plaintext string and returns an encrypted string.
// If the string is already encrypted, returns it as-is.
// Empty strings are returned unchanged.
Encrypt(itemID uuid.UUID, plaintext string) (string, error)
// Decrypt decrypts an encrypted string and returns a plaintext string.
// If the string is not encrypted, returns it as-is.
// Empty strings are returned unchanged.
Decrypt(itemID uuid.UUID, ciphertext string) (string, error)
}

Some files were not shown because too many files have changed in this diff Show More