Compare commits

...

154 Commits

Author SHA1 Message Date
Rostislav Dugin
920c98e229 Merge pull request #397 from databasus/develop
FIX (migrations): Fix version of migrations tool goose
2026-02-22 23:43:55 +03:00
Rostislav Dugin
2a19a96aae FIX (migrations): Fix version of migrations tool goose 2026-02-22 23:43:23 +03:00
Rostislav Dugin
75aa2108d9 Merge pull request #396 from databasus/develop
FIX (email): Use current OS hostname instead of default localhost
2026-02-22 23:33:28 +03:00
Rostislav Dugin
0a0040839e FIX (email): Use current OS hostname instead of default localhost 2026-02-22 23:31:25 +03:00
Rostislav Dugin
ff4f795ece Merge pull request #394 from databasus/develop
FIX (nas): Add NAS share validation
2026-02-22 16:05:38 +03:00
Rostislav Dugin
dc05502580 FIX (nas): Add NAS share validation 2026-02-22 15:56:30 +03:00
Rostislav Dugin
1ca38f5583 Merge pull request #390 from databasus/develop
FEATURE (templates): Add PR template
2026-02-21 15:58:21 +03:00
Rostislav Dugin
40b3ff61c7 FEATURE (templates): Add PR template 2026-02-21 15:53:01 +03:00
Rostislav Dugin
e1b245a573 Merge pull request #389 from databasus/develop
Develop
2026-02-21 14:57:56 +03:00
Rostislav Dugin
fdf29b71f2 FIX (mongodb): Fix direct connection string parsing 2026-02-21 14:56:48 +03:00
Rostislav Dugin
49da981c21 Merge pull request #388 from databasus/main
Merge main into dev
2026-02-21 14:53:31 +03:00
Rostislav Dugin
9d611d3559 REFACTOR (mongodb): Refactor direct connection PR 2026-02-21 14:43:47 +03:00
ujstor
22cab53dab feature/mongodb-directConnection (#377)
FEATURE (mongodb): Add direct connection
2026-02-21 14:10:28 +03:00
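To illustrate what the direct-connection feature above refers to: with the official MongoDB Go driver (v1 API), setting `directConnection=true` in the connection string pins the client to the named host instead of discovering and connecting to the whole replica set. This is a minimal sketch of the driver behaviour, not how Databasus parses these strings internally.

```go
package main

import (
	"context"
	"fmt"
	"time"

	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// directConnection=true: talk to this exact host, even if it is a
	// replica set member; without it the driver discovers the whole set.
	uri := "mongodb://localhost:27017/?directConnection=true"

	client, err := mongo.Connect(ctx, options.Client().ApplyURI(uri))
	if err != nil {
		panic(err)
	}
	defer func() { _ = client.Disconnect(ctx) }()

	if err := client.Ping(ctx, nil); err != nil {
		panic(err)
	}
	fmt.Println("connected directly to a single node")
}
```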
Rostislav Dugin
d761c4156c Merge pull request #385 from databasus/develop
FIX (readme): Fix README typo
2026-02-20 17:17:45 +03:00
Rostislav Dugin
cbb8b82711 FIX (readme): Fix README typo 2026-02-20 17:01:44 +03:00
Rostislav Dugin
8e3d1e5bff Merge pull request #384 from databasus/develop
FIX (backups): Do not reload backups if request already in progress
2026-02-20 15:04:19 +03:00
Rostislav Dugin
349e7f0ee8 FIX (backups): Do not reload backups if request already in progress 2026-02-20 14:43:07 +03:00
Rostislav Dugin
3a274e135b Merge pull request #383 from databasus/develop
FEATURE (backups): Add GFS retention policy
2026-02-20 14:33:29 +03:00
Rostislav Dugin
61e937bc2a FEATURE (backups): Add GFS retention policy 2026-02-20 14:31:56 +03:00
Rostislav Dugin
f67919fe1a Merge pull request #374 from databasus/develop
FIX (backups): Fix backup download and clean up
2026-02-18 12:53:10 +03:00
Rostislav Dugin
91ee5966d8 FIX (backups): Fix backup download and clean up 2026-02-18 12:52:35 +03:00
Rostislav Dugin
d77d7d69a3 Merge pull request #371 from databasus/develop
FEATURE (backups): Add metadata alongside backup files themselves to…
2026-02-17 19:54:53 +03:00
Rostislav Dugin
fc88b730d5 FEATURE (backups): Add metadata alongside backup files themselves to make them recoverable without Databasus 2026-02-17 19:52:08 +03:00
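The idea of this commit is that each backup file carries enough metadata to be restored by hand. A hypothetical sketch of such a sidecar file follows; the `BackupMetadata` fields and `writeSidecar` helper are illustrative, not Databasus's actual schema.

```go
package backups

import (
	"encoding/json"
	"os"
	"time"
)

// BackupMetadata is a hypothetical sidecar document describing a backup
// well enough to restore it without Databasus itself.
type BackupMetadata struct {
	Database    string    `json:"database"`
	Engine      string    `json:"engine"` // e.g. "postgresql"
	CreatedAt   time.Time `json:"createdAt"`
	Compression string    `json:"compression"`
	Encrypted   bool      `json:"encrypted"`
}

// writeSidecar stores metadata next to the backup file, e.g.
// "backup.dump" -> "backup.dump.meta.json".
func writeSidecar(backupPath string, meta BackupMetadata) error {
	data, err := json.MarshalIndent(meta, "", "  ")
	if err != nil {
		return err
	}
	return os.WriteFile(backupPath+".meta.json", data, 0o644)
}
```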
Rostislav Dugin
1f1d80245f Merge pull request #368 from databasus/develop
FIX (restores): Increase restore timeout to 23 hours instead of 1 hour
2026-02-17 14:56:58 +03:00
Rostislav Dugin
16a29cf458 FIX (restores): Increase restore timeout to 23 hours instead of 1 hour 2026-02-17 14:56:25 +03:00
Rostislav Dugin
43e04500ac Merge pull request #367 from databasus/develop
FEATURE (backups): Add meaningful names for backups
2026-02-17 14:50:21 +03:00
Rostislav Dugin
cee3022f85 FEATURE (backups): Add meaningful names for backups 2026-02-17 14:49:33 +03:00
Rostislav Dugin
f46d92c480 Merge pull request #365 from databasus/develop
FIX (audit logs): Get rid of IDs in audit logs and improve naming log…
2026-02-15 01:10:54 +03:00
Rostislav Dugin
10677238d7 FIX (audit logs): Get rid of IDs in audit logs and improve naming logging 2026-02-15 01:06:39 +03:00
Rostislav Dugin
2553203fcf Merge pull request #363 from databasus/develop
FIX (sign up): Return authorization token on sign up to avoid 2-step …
2026-02-15 00:09:00 +03:00
Rostislav Dugin
7b05bd8000 FIX (sign up): Return authorization token on sign up to avoid 2-step sign up 2026-02-15 00:08:01 +03:00
Rostislav Dugin
8d45728f73 Merge pull request #362 from databasus/develop
FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign …
2026-02-14 23:19:12 +03:00
Rostislav Dugin
c70ad82c95 FEATURE (auth): Add optional CloudFlare Turnstile for sign in \ sign up \ password reset 2026-02-14 23:11:36 +03:00
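For context on the server side of this feature: validating a Turnstile token is a single POST to Cloudflare's documented siteverify endpoint with the secret key and the token produced by the browser widget. A minimal sketch, with an illustrative helper name and reduced error handling:

```go
package auth

import (
	"context"
	"encoding/json"
	"net/http"
	"net/url"
	"strings"
)

const turnstileVerifyURL = "https://challenges.cloudflare.com/turnstile/v0/siteverify"

// verifyTurnstile checks a token produced by the Turnstile widget.
// secret is the server-side Turnstile secret key.
func verifyTurnstile(ctx context.Context, secret, token string) (bool, error) {
	form := url.Values{"secret": {secret}, "response": {token}}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost,
		turnstileVerifyURL, strings.NewReader(form.Encode()))
	if err != nil {
		return false, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return false, err
	}
	defer resp.Body.Close()

	var out struct {
		Success bool `json:"success"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return false, err
	}
	return out.Success, nil
}
```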
Rostislav Dugin
e4bc34d319 Merge pull request #361 from databasus/develop
Develop
2026-02-13 16:57:25 +03:00
Rostislav Dugin
257ae85da7 FIX (postgres): Fix read-only issue when user cannot access tables and partitions created after user creation 2026-02-13 16:56:56 +03:00
Rostislav Dugin
b42c820bb2 FIX (mariadb): Fix events exclusion 2026-02-13 16:21:48 +03:00
Rostislav Dugin
da5c13fb11 Merge pull request #356 from databasus/develop
FIX (mysql & mariadb): Fix creation of backups with extremely large SQ…
2026-02-10 22:40:06 +03:00
Rostislav Dugin
35180360e5 FIX (mysql & mariadb): Fix creation of backups with extremely large SQL statements to avoid OOM 2026-02-10 22:38:18 +03:00
Rostislav Dugin
e4f6cd7a5d Merge pull request #349 from databasus/develop
Develop
2026-02-09 16:42:00 +03:00
Rostislav Dugin
d7b8e6d56a Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-09 16:40:46 +03:00
Rostislav Dugin
6016f23fb2 FEATURE (svr): Add SVR support 2026-02-09 16:39:51 +03:00
Rostislav Dugin
e7c4ee8f6f Merge pull request #345 from databasus/develop
Develop
2026-02-08 23:38:42 +03:00
Rostislav Dugin
a75702a01b Merge pull request #342 from wuast94/patch-1
Add image source label to dockerfiles
2026-02-08 23:38:18 +03:00
Rostislav Dugin
81a21eb907 FEATURE (google drive): Change OAuth authorization flow to local address instead of databasus.com 2026-02-08 23:32:13 +03:00
Marc
33d6bf0147 Add image source label to dockerfiles
To get changelogs shown with Renovate a docker container has to add the source label described in the OCI Image Format Specification.

For reference: https://github.com/renovatebot/renovate/blob/main/lib/modules/datasource/docker/readme.md
2026-02-05 23:30:37 +01:00
Rostislav Dugin
6eb53bb07b Merge pull request #341 from databasus/develop
Develop
2026-02-06 00:25:30 +03:00
Rostislav Dugin
6ac04270b9 FEATURE (healthcheck): Add checking whether backup nodes available for primary node 2026-02-06 00:24:34 +03:00
Rostislav Dugin
b0510d7c21 FIX (logging): Add login to VictoriaLogs logger 2026-02-06 00:18:09 +03:00
Rostislav Dugin
dc5f271882 Merge pull request #339 from databasus/develop
FIX (storages): Do not remove system storage on any workspace deletion
2026-02-05 01:32:46 +03:00
Rostislav Dugin
8f718771c9 FIX (storages): Do not remove system storage on any workspace deletion 2026-02-05 01:32:21 +03:00
Rostislav Dugin
d8eea05dca Merge pull request #332 from databasus/develop
FIX (script): Fix script creation in playground head x2
2026-02-02 20:46:35 +03:00
Rostislav Dugin
b2a94274d7 FIX (script): Fix script creation in playground head x2 2026-02-02 20:44:52 +03:00
Rostislav Dugin
77c2712ebb Merge pull request #331 from databasus/develop
FIX (script): Fix script creation in playground head
2026-02-02 19:47:44 +03:00
Rostislav Dugin
a9dc29f82c FIX (script): Fix script creation in playground head 2026-02-02 19:47:15 +03:00
Rostislav Dugin
c934a45dca Merge pull request #330 from databasus/develop
FIX (storages): Fix storage edit in playground
2026-02-02 18:51:47 +03:00
Rostislav Dugin
d4acdf2826 FIX (storages): Fix storage edit in playground 2026-02-02 18:48:19 +03:00
Rostislav Dugin
49753c4fc0 Merge pull request #329 from databasus/develop
FIX (s3): Fix S3 prefill in playground on form edit
2026-02-02 18:14:07 +03:00
Rostislav Dugin
c6aed6b36d FIX (s3): Fix S3 prefill in playground on form edit 2026-02-02 18:12:44 +03:00
Rostislav Dugin
3060b4266a Merge pull request #328 from databasus/develop
Develop
2026-02-02 17:53:05 +03:00
Rostislav Dugin
ebeb597f17 FEATURE (playground): Add support of Rybbit script for playground 2026-02-02 17:50:31 +03:00
Rostislav Dugin
4783784325 FIX (playground): Do not show whitelist message in playground 2026-02-02 16:53:01 +03:00
Rostislav Dugin
bd41433bdb Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-02 16:50:18 +03:00
Rostislav Dugin
a9073787d2 FIX (audit logs): In dark mode show white text in audit logs 2026-02-02 16:44:49 +03:00
Rostislav Dugin
0890bf8f09 Merge pull request #327 from artemkalugin01/access-management-href-fix
Fix href in settings for access-management#global-settings
2026-02-02 16:12:25 +03:00
artem.kalugin
f8c11e8802 Fix href typo in settings for access-management#global-settings 2026-02-02 12:59:56 +03:00
Rostislav Dugin
e798d82fc1 Merge pull request #325 from databasus/develop
FIX (storages): Fix default storage type prefill in playground
2026-02-01 20:12:12 +03:00
Rostislav Dugin
81a01585ee FIX (storages): Fix default storage type prefill in playground 2026-02-01 20:07:12 +03:00
Rostislav Dugin
a8465c1a10 Merge pull request #324 from databasus/develop
FIX (storages): Limit local storage usage in playground
2026-02-01 19:20:34 +03:00
Rostislav Dugin
a9e5db70f6 FIX (storages): Limit local storage usage in playground 2026-02-01 19:18:54 +03:00
Rostislav Dugin
7a47be6ca6 Merge pull request #323 from databasus/develop
Develop
2026-02-01 18:42:30 +03:00
Rostislav Dugin
16be3db0c6 FIX (playground): Pre-select system storage if exists in playground 2026-02-01 18:30:50 +03:00
Rostislav Dugin
744e51d1e1 REFACTOR (email): Refactor commit adding date headers to emails 2026-02-01 16:43:53 +03:00
Rostislav Dugin
b3af75d430 Merge branch 'develop' of https://github.com/databasus/databasus into develop 2026-02-01 16:41:52 +03:00
mcarbs
6f7320abeb FIX (email): Add email date header 2026-02-01 16:41:17 +03:00
Rostislav Dugin
a1655d35a6 FIX (healthcheck): Add cache accessibility to healthcheck 2026-01-30 16:33:39 +03:00
Rostislav Dugin
9b6e801184 Merge pull request #316 from databasus/develop
FEATURE (email): Add sending email about members invitation and passw…
2026-01-28 17:29:58 +03:00
Rostislav Dugin
105777ab6f FEATURE (email): Add sending email about members invitation and password reset 2026-01-28 17:28:36 +03:00
Rostislav Dugin
3a1a88d5cf Merge pull request #315 from databasus/develop
FIX (env): Fix env detection over startup
2026-01-28 11:33:06 +03:00
Rostislav Dugin
699ca16814 FIX (env): Fix env detection over startup 2026-01-28 11:32:19 +03:00
Rostislav Dugin
26f3cf233a Merge pull request #313 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x3
2026-01-27 17:04:25 +03:00
Rostislav Dugin
3d8372e9f6 FIX (backups): Improve cascade deletion of backups on storage removal x3 2026-01-27 17:03:51 +03:00
Rostislav Dugin
b46f11804d Merge pull request #312 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal x2
2026-01-27 16:38:49 +03:00
Rostislav Dugin
4676361688 FIX (backups): Improve cascade deletion of backups on storage removal x2 2026-01-27 16:38:21 +03:00
Databasus
de3679cadf Merge pull request #310 from databasus/develop
FIX (backups): Improve cascade deletion of backups on storage removal
2026-01-27 16:29:13 +03:00
Rostislav Dugin
8f03a30af2 FIX (backups): Improve cascade deletion of backups on storage removal 2026-01-27 16:28:06 +03:00
Rostislav Dugin
356529c58a Merge pull request #309 from databasus/develop
FIX (tests): Fix database backups cleanup when DI does not allow to d…
2026-01-27 15:39:53 +03:00
Rostislav Dugin
e7eed056f7 FIX (tests): Fix database backups cleanup when DI does not allow to delete backups via listeners 2026-01-27 15:39:04 +03:00
Rostislav Dugin
6084cdc954 Merge pull request #308 from databasus/develop
FIX (tests): Increase cascade deletion timeouts in tests
2026-01-27 15:24:15 +03:00
Rostislav Dugin
c50bcc57b1 FIX (tests): Increase cascade deletion timeouts in tests 2026-01-27 15:23:13 +03:00
Rostislav Dugin
ea76300ed7 Merge pull request #307 from databasus/develop
Develop
2026-01-27 15:07:56 +03:00
Rostislav Dugin
9b413e4076 FIX (tests): Improve cleaning up of backups and workspaces 2026-01-27 15:07:20 +03:00
Rostislav Dugin
f91cb260f2 FEATURE (logs): Add Victoria Logs 2026-01-27 15:07:20 +03:00
Rostislav Dugin
8f37a8082f FIX (db): Decrease connections count for DB 2026-01-27 15:07:20 +03:00
Rostislav Dugin
5cf7614772 FIX (playground): Make playground multiple nodes 2026-01-24 14:57:45 +03:00
Rostislav Dugin
ae27f74c2e Merge pull request #304 from databasus/develop
FIX (playground): Fix flaky test with impossible value
2026-01-23 12:38:06 +03:00
Rostislav Dugin
9457516bb9 FIX (playground): Fix flaky test with impossible value 2026-01-23 12:37:10 +03:00
Rostislav Dugin
a36fc5bf8c Merge pull request #303 from databasus/develop
Develop
2026-01-23 12:24:29 +03:00
Rostislav Dugin
03ada5806d FEATURE (pre-commit): Add building step to pre-commit 2026-01-23 12:22:31 +03:00
Rostislav Dugin
a6675390e5 FIX (cors): Allow CORS for healthcheck endpoint 2026-01-23 12:04:29 +03:00
Rostislav Dugin
af2f978876 FEATURE (playground): Add playground 2026-01-23 12:00:56 +03:00
Rostislav Dugin
04e7eba5c5 Merge pull request #300 from databasus/develop
FIX (ci \ cd): Add build step after lint step for frontend to catch b…
2026-01-20 08:40:14 +03:00
Rostislav Dugin
520165541d FIX (ci \ cd): Add build step after lint step for frontend to catch build issues 2026-01-20 08:39:28 +03:00
Rostislav Dugin
5b556bc161 Merge pull request #299 from databasus/develop
Develop
2026-01-20 08:26:57 +03:00
Rostislav Dugin
0952a15ec5 FEATURE (navbar): Update navbar style 2026-01-20 08:25:58 +03:00
Rostislav Dugin
1afb3aa3ff Merge pull request #298 from tim-sas-kramp/main
FIX (theme): Integrate theme support for GitHub button color scheme
2026-01-20 07:25:57 +03:00
tim-sas-kramp
19b92e5f74 FIX (theme): Integrate theme support for GitHub button color scheme 2026-01-19 21:17:24 +00:00
Rostislav Dugin
d4763f26b2 Merge pull request #296 from databasus/develop
Develop
2026-01-19 19:27:03 +03:00
Rostislav Dugin
0e389ba16b FIX (backups): Allow parallel backups for different DBs 2026-01-19 19:26:03 +03:00
Rostislav Dugin
594a3294c6 FEATURE (limits): Add max backup size limit and total backups size limit 2026-01-19 19:26:03 +03:00
Rostislav Dugin
4e4a323cf1 FEATURE (config): Suggest read-only user creation when DB config changed 2026-01-19 19:26:03 +03:00
Rostislav Dugin
7d9ecf697b FIX (backups): Do not allow 2 parallel backups for the same DB 2026-01-19 19:26:03 +03:00
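These two commits together describe per-database locking: backups of different databases may overlap, but a second backup of the same database is rejected. A minimal sketch of that invariant using one mutex per database ID; the names are illustrative, not Databasus's actual code.

```go
package backups

import "sync"

// dbLocks lazily holds one mutex per database ID.
var dbLocks sync.Map // map[string]*sync.Mutex

// tryStartBackup returns a release func when no backup is running for
// dbID, or ok=false when one is already in progress. Different IDs map
// to different mutexes, so backups of different DBs run in parallel.
func tryStartBackup(dbID string) (release func(), ok bool) {
	m, _ := dbLocks.LoadOrStore(dbID, &sync.Mutex{})
	mu := m.(*sync.Mutex)
	if !mu.TryLock() { // TryLock requires Go 1.18+
		return nil, false
	}
	return mu.Unlock, true
}
```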
Rostislav Dugin
755c420157 Merge pull request #294 from databasus/develop
FIX (mysql \ mariadb): Add escaping underscored DB names over health c…
2026-01-19 12:07:18 +03:00
Rostislav Dugin
ff73627287 FIX (mysql \ mariadb): Add escaping underscored DB names over health check 2026-01-19 11:34:37 +03:00
Rostislav Dugin
9c9ab00ace Merge pull request #292 from databasus/develop
FIX (postgresql): Do not throw an error over read-only user creation …
2026-01-18 23:08:55 +03:00
Rostislav Dugin
7366e21a1a FIX (postgresql): Do not throw an error over read-only user creation if there is no public schema in DB 2026-01-18 22:57:47 +03:00
Rostislav Dugin
a327d1aa57 Merge pull request #290 from databasus/develop
FIX (ftp): Add support of nested folders
2026-01-18 18:34:45 +03:00
Rostislav Dugin
f152b16ea3 FIX (ftp): Add support of nested folders 2026-01-18 18:34:13 +03:00
Databasus
85dbe80d3d Merge pull request #288 from databasus/develop
FIX (email): Add following RFC 2047 for emails
2026-01-18 17:59:17 +03:00
Rostislav Dugin
edf4028fd1 FIX (email): Add following RFC 2047 for emails 2026-01-18 17:58:31 +03:00
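RFC 2047 defines "encoded words" so that non-ASCII text can travel safely in email headers such as Subject. Go's standard library already implements it; a minimal sketch:

```go
package main

import (
	"fmt"
	"mime"
)

func main() {
	// Encode a non-ASCII subject as an RFC 2047 encoded word.
	subject := "Резервная копия готова"
	encoded := mime.QEncoding.Encode("utf-8", subject)
	fmt.Println(encoded) // =?utf-8?q?...?=

	// The receiving side decodes it back.
	var dec mime.WordDecoder
	decoded, err := dec.DecodeHeader(encoded)
	if err != nil {
		panic(err)
	}
	fmt.Println(decoded)
}
```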
Databasus
8d85c45a90 Merge pull request #287 from databasus/develop
FIX (tests): Allow to skip external network tests in CI CD
2026-01-18 15:46:49 +03:00
Rostislav Dugin
d9c176d19a FIX (tests): Allow to skip external network tests in CI CD 2026-01-18 15:45:49 +03:00
Databasus
7a6f72a456 Merge pull request #286 from databasus/develop
FIX (ci): Add cleanup to build and push steps
2026-01-18 15:09:13 +03:00
Rostislav Dugin
9a1471b88b FIX (ci): Add cleanup to build and push steps 2026-01-18 15:08:09 +03:00
Databasus
386ea1d708 Merge pull request #285 from databasus/develop
FIX (commit messages): Allow to use backslashes in messages x3
2026-01-18 14:58:10 +03:00
Rostislav Dugin
a4b23936ee FIX (commit messages): Allow to use backslashes in messages x3 2026-01-18 14:57:45 +03:00
Databasus
b36aa9d48b Merge pull request #284 from databasus/develop
FIX (commit messages): Allow to use backslashes in messages x2
2026-01-18 14:49:58 +03:00
Rostislav Dugin
13cb8e5bd2 FIX (commit messages): Allow to use backslashes in messages x2 2026-01-18 14:49:18 +03:00
Databasus
2db4b6e075 Merge pull request #283 from databasus/develop
FIX (commit messages): Allow to use backslashes in messages
2026-01-18 14:38:34 +03:00
Rostislav Dugin
f2b0b2bf1f FIX (commit messages): Allow to use backslashes in messages 2026-01-18 14:38:12 +03:00
Databasus
7142ce295e Merge pull request #282 from databasus/develop
Develop
2026-01-18 14:01:59 +03:00
Rostislav Dugin
04621b9b2d FEATURE (ci \ cd): Adjust CI \ CD to run heavy jobs on self hosted performant runner 2026-01-18 13:55:08 +03:00
Rostislav Dugin
bd329a68cf FEATURE (restores): Do not allow to make 2 parallel restores for single DB 2026-01-17 22:50:35 +03:00
Rostislav Dugin
f957abc9db FEATURE (restores): Add cancellation of restore process 2026-01-17 22:35:47 +03:00
Rostislav Dugin
c0fd6be1a9 Merge pull request #280 from databasus/develop
FEATURE (restores): Add support of multiple restores nodes
2026-01-17 13:59:36 +03:00
Rostislav Dugin
c39bd34d5e FEATURE (restores): Add support of multiple restores nodes 2026-01-17 13:59:06 +03:00
Rostislav Dugin
27bec15a29 Merge pull request #278 from databasus/develop
FIX (backups): Extend filtering lists to detect from-image DB access
2026-01-16 10:03:45 +03:00
Rostislav Dugin
d98baa0656 FIX (backups): Extend filtering lists to detect from-image DB access 2026-01-16 10:03:09 +03:00
Rostislav Dugin
4344f5ea5e Merge pull request #273 from databasus/develop
FIX (ci \ cd): Make DB files in CI \ CD executable
2026-01-15 22:17:06 +03:00
Rostislav Dugin
7c6afa5b88 FIX (ci \ cd): Make DB files in CI \ CD executable 2026-01-15 22:16:45 +03:00
Rostislav Dugin
dbac799e1b Merge pull request #272 from databasus/develop
FIX (backups): Add backups failure logging when it is expected
2026-01-15 22:02:39 +03:00
Rostislav Dugin
7ee3817089 FIX (backups): Add backups failure logging when it is expected 2026-01-15 22:01:53 +03:00
Rostislav Dugin
bae6f7f007 Merge pull request #271 from databasus/develop
Develop
2026-01-15 21:19:55 +03:00
Rostislav Dugin
55dc087ddd FIX (containers): Do not allow to backup internal DB from inside containers, instead give link to FAQ with manual how to backup Databasus in proper way 2026-01-15 21:18:37 +03:00
Rostislav Dugin
c94d0db637 FIX (ci \ cd): Remove caches and use assets from repo to avoid flaky tests over CI 2026-01-15 21:03:43 +03:00
Rostislav Dugin
a1adef2261 !REFACTOR (tasks): Move tasks cancellation and tracking to separate package from backuping to use for restores 2026-01-15 21:03:05 +03:00
Rostislav Dugin
4602dc3f88 Merge pull request #267 from databasus/develop
FIX (mysql): Enable allowCleartextPasswords over SSL
2026-01-14 18:13:46 +03:00
Rostislav Dugin
cbbfc5ea8f FIX (mysql): Enable allowCleartextPasswords over SSL 2026-01-14 18:11:49 +03:00
Rostislav Dugin
dd1072e230 Merge pull request #265 from databasus/develop
FIX (pre-commit): Add running go mod tidy in pre-commit
2026-01-14 15:18:35 +03:00
Rostislav Dugin
a495e5317a FIX (pre-commit): Add running go mod tidy in pre-commit 2026-01-14 15:18:06 +03:00
Rostislav Dugin
7eed647038 Merge pull request #264 from databasus/develop
Develop
2026-01-14 15:14:05 +03:00
Rostislav Dugin
6973241e25 FIX (backups): Throw error on parallel download token generation 2026-01-14 14:40:22 +03:00
Rostislav Dugin
ab181f5b81 FEATURE (bandwidth): Limit download throughput for backups to not exhaust more than 75% of server network bandwidth 2026-01-14 14:40:22 +03:00
Rostislav Dugin
b60a0cc170 FEATURE (backups): Allow single backup download to avoid exhausting of server throughput 2026-01-14 14:40:22 +03:00
Rostislav Dugin
f319a497b3 FEATURE (auth): Add rate limiting for sign in via email using sliding window 2026-01-14 14:40:22 +03:00
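A sliding-window limiter counts exact attempt timestamps inside a moving window, so limits do not reset abruptly at bucket boundaries the way fixed windows do. A minimal in-memory sketch of the technique named in this commit (illustrative only, not the project's actual implementation):

```go
package auth

import (
	"sync"
	"time"
)

// slidingWindowLimiter allows at most maxHits events per window per key,
// where a key is e.g. an email address or client IP.
type slidingWindowLimiter struct {
	mu      sync.Mutex
	window  time.Duration
	maxHits int
	hits    map[string][]time.Time
}

func newLimiter(window time.Duration, maxHits int) *slidingWindowLimiter {
	return &slidingWindowLimiter{
		window:  window,
		maxHits: maxHits,
		hits:    map[string][]time.Time{},
	}
}

// Allow reports whether another sign-in attempt for key is permitted,
// recording the attempt when it is.
func (l *slidingWindowLimiter) Allow(key string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-l.window)

	// Drop attempts that slid out of the window (in-place filter).
	kept := l.hits[key][:0]
	for _, t := range l.hits[key] {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	if len(kept) >= l.maxHits {
		l.hits[key] = kept
		return false
	}
	l.hits[key] = append(kept, now)
	return true
}
```

For example, `newLimiter(15*time.Minute, 5)` rejects a sixth attempt from the same key within any 15-minute span.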
278 changed files with 23744 additions and 4261 deletions

.github/ISSUE_TEMPLATE/bug_report.md (new file, 38 lines)

@@ -0,0 +1,38 @@
---
name: Bug Report
about: Report a bug or unexpected behavior in Databasus
labels: bug
---
## Databasus version
<!-- e.g. 1.4.2 -->
## Operating system and architecture
<!-- e.g. Ubuntu 22.04 x64, macOS 14 ARM, Windows 11 x64 -->
## Describe the bug (please write manually, do not ask AI to summarize)
**What happened:**
**What I expected:**
## Steps to reproduce
1.
2.
3.
## Have you asked AI how to solve the issue?
<!-- Using AI to diagnose issues before filing a bug report helps narrow down root causes. -->
- [ ] Claude Sonnet 4.6 or newer
- [ ] ChatGPT 5.2 or newer
- [ ] No
## Additional context / logs
<!-- Screenshots, error messages, relevant log output, etc. -->


@@ -9,25 +9,26 @@ on:
jobs:
lint-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: golang:1.24.9
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/golangci-lint:/root/.cache/golangci-lint
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Download Go modules
run: |
cd backend
go mod download
- name: Install golangci-lint
run: |
@@ -63,8 +64,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -82,6 +81,11 @@ jobs:
cd frontend
npm run lint
- name: Build frontend
run: |
cd frontend
npm run build
test-frontend:
runs-on: ubuntu-latest
needs: [lint-frontend]
@@ -93,8 +97,6 @@ jobs:
uses: actions/setup-node@v4
with:
node-version: "20"
cache: "npm"
cache-dependency-path: frontend/package-lock.json
- name: Install dependencies
run: |
@@ -107,44 +109,32 @@ jobs:
npm run test
test-backend:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [lint-backend]
container:
image: golang:1.24.9
options: --privileged -v /var/run/docker.sock:/var/run/docker.sock --add-host=host.docker.internal:host-gateway
volumes:
- /runner-cache/go-pkg:/go/pkg/mod
- /runner-cache/go-build:/root/.cache/go-build
- /runner-cache/apt-archives:/var/cache/apt/archives
steps:
- name: Free up disk space
- name: Install Docker CLI
run: |
echo "Disk space before cleanup:"
df -h
# Remove unnecessary pre-installed software
sudo rm -rf /usr/share/dotnet
sudo rm -rf /usr/local/lib/android
sudo rm -rf /opt/ghc
sudo rm -rf /opt/hostedtoolcache/CodeQL
sudo rm -rf /usr/local/share/boost
sudo rm -rf /usr/share/swift
# Clean apt cache
sudo apt-get clean
# Clean docker images (if any pre-installed)
docker system prune -af --volumes || true
echo "Disk space after cleanup:"
df -h
apt-get update -qq
apt-get install -y -qq docker.io docker-compose netcat-openbsd wget
- name: Check out code
uses: actions/checkout@v4
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version: "1.24.9"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Cache Go modules
uses: actions/cache@v4
with:
path: |
~/go/pkg/mod
~/.cache/go-build
key: ${{ runner.os }}-go-${{ hashFiles('backend/go.sum') }}
restore-keys: |
${{ runner.os }}-go-
- name: Download Go modules
run: |
cd backend
go mod download
- name: Create .env file for testing
run: |
@@ -156,14 +146,16 @@ jobs:
DEV_DB_PASSWORD=Q1234567
#app
ENV_MODE=development
# db
DATABASE_DSN=host=localhost user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
# db - using 172.17.0.1 to access host from container
DATABASE_DSN=host=172.17.0.1 user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
# migrations
GOOSE_DRIVER=postgres
GOOSE_DBSTRING=postgres://postgres:Q1234567@localhost:5437/databasus?sslmode=disable
GOOSE_DBSTRING=postgres://postgres:Q1234567@172.17.0.1:5437/databasus?sslmode=disable
GOOSE_MIGRATION_DIR=./migrations
# testing
# testing
TEST_LOCALHOST=172.17.0.1
IS_SKIP_EXTERNAL_RESOURCES_TESTS=true
# to get Google Drive env variables: add storage in UI and copy data from added storage here
TEST_GOOGLE_DRIVE_CLIENT_ID=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_ID }}
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
@@ -221,12 +213,14 @@ jobs:
TEST_MONGODB_60_PORT=27060
TEST_MONGODB_70_PORT=27070
TEST_MONGODB_82_PORT=27082
# Valkey (cache)
VALKEY_HOST=localhost
# Valkey (cache) - using 172.17.0.1
VALKEY_HOST=172.17.0.1
VALKEY_PORT=6379
VALKEY_USERNAME=
VALKEY_PASSWORD=
VALKEY_IS_SSL=false
# Host for test databases (container -> host)
TEST_DB_HOST=172.17.0.1
EOF
- name: Start test containers
@@ -244,25 +238,25 @@ jobs:
timeout 60 bash -c 'until docker exec dev-valkey valkey-cli ping 2>/dev/null | grep -q PONG; do sleep 2; done'
echo "Valkey is ready!"
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5005; do sleep 2; done'
# Wait for test databases (using 172.17.0.1 from container)
timeout 60 bash -c 'until nc -z 172.17.0.1 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5003; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5004; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 5005; do sleep 2; done'
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 10000; do sleep 2; done'
# Wait for FTP
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7007; do sleep 2; done'
# Wait for SFTP
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
timeout 60 bash -c 'until nc -z 172.17.0.1 7008; do sleep 2; done'
# Wait for MySQL containers
echo "Waiting for MySQL 5.7..."
@@ -321,67 +315,66 @@ jobs:
mkdir -p databasus-data/backups
mkdir -p databasus-data/temp
- name: Cache PostgreSQL client tools
id: cache-postgres
uses: actions/cache@v4
with:
path: /usr/lib/postgresql
key: postgres-clients-12-18-v1
- name: Cache MySQL client tools
id: cache-mysql
uses: actions/cache@v4
with:
path: backend/tools/mysql
key: mysql-clients-57-80-84-9-v1
- name: Cache MariaDB client tools
id: cache-mariadb
uses: actions/cache@v4
with:
path: backend/tools/mariadb
key: mariadb-clients-106-121-v1
- name: Cache MongoDB Database Tools
id: cache-mongodb
uses: actions/cache@v4
with:
path: backend/tools/mongodb
key: mongodb-database-tools-100.10.0-v1
- name: Install MySQL dependencies
- name: Install database client dependencies
run: |
sudo apt-get update -qq
sudo apt-get install -y -qq libncurses6
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
apt-get update -qq
apt-get install -y -qq libncurses6 libpq5
ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5 || true
ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5 || true
- name: Install PostgreSQL, MySQL, MariaDB and MongoDB client tools
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true' || steps.cache-mongodb.outputs.cache-hit != 'true'
run: |
chmod +x backend/tools/download_linux.sh
cd backend/tools
./download_linux.sh
- name: Setup PostgreSQL symlinks (when using cache)
if: steps.cache-postgres.outputs.cache-hit == 'true'
- name: Setup PostgreSQL, MySQL and MariaDB client tools from pre-built assets
run: |
cd backend/tools
mkdir -p postgresql
# Create directory structure
mkdir -p postgresql mysql mariadb mongodb/bin
# Copy PostgreSQL client tools (12-18) from pre-built assets
for version in 12 13 14 15 16 17 18; do
version_dir="postgresql/postgresql-$version"
mkdir -p "$version_dir/bin"
pg_bin_dir="/usr/lib/postgresql/$version/bin"
if [ -d "$pg_bin_dir" ]; then
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
fi
mkdir -p postgresql/postgresql-$version
cp -r ../../assets/tools/x64/postgresql/postgresql-$version/bin postgresql/postgresql-$version/
done
# Copy MySQL client tools (5.7, 8.0, 8.4, 9) from pre-built assets
for version in 5.7 8.0 8.4 9; do
mkdir -p mysql/mysql-$version
cp -r ../../assets/tools/x64/mysql/mysql-$version/bin mysql/mysql-$version/
done
# Copy MariaDB client tools (10.6, 12.1) from pre-built assets
for version in 10.6 12.1; do
mkdir -p mariadb/mariadb-$version
cp -r ../../assets/tools/x64/mariadb/mariadb-$version/bin mariadb/mariadb-$version/
done
# Make all binaries executable
chmod +x postgresql/*/bin/*
chmod +x mysql/*/bin/*
chmod +x mariadb/*/bin/*
echo "Pre-built client tools setup complete"
- name: Install MongoDB Database Tools
run: |
cd backend/tools
# MongoDB Database Tools must be downloaded (not in pre-built assets)
# They are backward compatible - single version supports all servers (4.0-8.0)
MONGODB_TOOLS_URL="https://fastdl.mongodb.org/tools/db/mongodb-database-tools-debian12-x86_64-100.10.0.deb"
echo "Downloading MongoDB Database Tools..."
wget -q "$MONGODB_TOOLS_URL" -O /tmp/mongodb-database-tools.deb
echo "Installing MongoDB Database Tools..."
dpkg -i /tmp/mongodb-database-tools.deb || apt-get install -f -y --no-install-recommends
# Create symlinks to tools directory
ln -sf /usr/bin/mongodump mongodb/bin/mongodump
ln -sf /usr/bin/mongorestore mongodb/bin/mongorestore
rm -f /tmp/mongodb-database-tools.deb
echo "MongoDB Database Tools installed successfully"
- name: Verify MariaDB client tools exist
run: |
cd backend/tools
@@ -414,7 +407,7 @@ jobs:
- name: Run database migrations
run: |
cd backend
go install github.com/pressly/goose/v3/cmd/goose@latest
go install github.com/pressly/goose/v3/cmd/goose@v3.24.3
goose up
- name: Run Go tests
@@ -426,10 +419,28 @@ jobs:
if: always()
run: |
cd backend
# Stop and remove containers (keeping images for next run)
docker compose -f docker-compose.yml.example down -v
# Clean up all data directories created by docker-compose
echo "Cleaning up data directories..."
rm -rf pgdata || true
rm -rf valkey-data || true
rm -rf mysqldata || true
rm -rf mariadbdata || true
rm -rf temp/nas || true
rm -rf databasus-data || true
# Also clean root-level databasus-data if exists
cd ..
rm -rf databasus-data || true
echo "Cleanup complete"
determine-version:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
outputs:
@@ -442,10 +453,9 @@ jobs:
with:
fetch-depth: 0
- name: Set up Node.js
uses: actions/setup-node@v4
with:
node-version: "20"
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install semver
run: npm install -g semver
@@ -459,6 +469,7 @@ jobs:
- name: Analyze commits and determine version bump
id: version_bump
shell: bash
run: |
CURRENT_VERSION="${{ steps.current_version.outputs.current_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -478,7 +489,7 @@ jobs:
HAS_FIX=false
HAS_BREAKING=false
# Analyze each commit
# Analyze each commit - USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r commit; do
if [[ "$commit" =~ ^FEATURE ]]; then
HAS_FEATURE=true
@@ -496,7 +507,7 @@ jobs:
HAS_BREAKING=true
echo "Found BREAKING CHANGE: $commit"
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Determine version bump
if [ "$HAS_BREAKING" = true ]; then
@@ -522,10 +533,15 @@ jobs:
fi
build-only:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [test-backend, test-frontend]
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -554,12 +570,17 @@ jobs:
databasus/databasus:${{ github.sha }}
build-and-push:
runs-on: ubuntu-latest
runs-on: self-hosted
needs: [determine-version]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
steps:
- name: Clean workspace
run: |
sudo rm -rf "$GITHUB_WORKSPACE"/* || true
sudo rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
@@ -589,21 +610,33 @@ jobs:
databasus/databasus:${{ github.sha }}
release:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: node:20
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: write
pull-requests: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Check out code
uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Generate changelog
id: changelog
shell: bash
run: |
NEW_VERSION="${{ needs.determine-version.outputs.new_version }}"
LATEST_TAG=$(git describe --tags --abbrev=0 2>/dev/null || echo "v0.0.0")
@@ -623,6 +656,7 @@ jobs:
FIXES=""
REFACTORS=""
# USE PROCESS SUBSTITUTION to avoid subshell variable scope issues
while IFS= read -r line; do
if [ -n "$line" ]; then
COMMIT_MSG=$(echo "$line" | cut -d'|' -f1)
@@ -656,7 +690,7 @@ jobs:
fi
fi
fi
done <<< "$COMMITS"
done < <(printf '%s\n' "$COMMITS")
# Build changelog sections
if [ -n "$FEATURES" ]; then
@@ -695,16 +729,33 @@ jobs:
prerelease: false
publish-helm-chart:
runs-on: ubuntu-latest
runs-on: self-hosted
container:
image: alpine:3.19
volumes:
- /runner-cache/apk-cache:/etc/apk/cache
needs: [determine-version, build-and-push]
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
permissions:
contents: read
packages: write
steps:
- name: Clean workspace
run: |
rm -rf "$GITHUB_WORKSPACE"/* || true
rm -rf "$GITHUB_WORKSPACE"/.* || true
- name: Install dependencies
run: |
apk add --no-cache git bash curl
- name: Check out code
uses: actions/checkout@v4
- name: Configure Git for container
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Set up Helm
uses: azure/setup-helm@v4
with:

.gitignore (4 lines changed)

@@ -1,3 +1,4 @@
ansible/
postgresus_data/
postgresus-data/
databasus-data/
@@ -9,4 +10,5 @@ node_modules/
/articles
.DS_Store
/scripts
/scripts
.vscode/settings.json


@@ -18,6 +18,13 @@ repos:
files: ^frontend/.*\.(ts|tsx|js|jsx)$
pass_filenames: false
- id: frontend-build
name: Frontend Build
entry: bash -c "cd frontend && npm run build"
language: system
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css)$
pass_filenames: false
# Backend checks
- repo: local
hooks:
@@ -27,3 +34,10 @@ repos:
language: system
files: ^backend/.*\.go$
pass_filenames: false
- id: backend-go-mod-tidy
name: Backend Go Mod Tidy
entry: bash -c "cd backend && go mod tidy"
language: system
files: ^backend/.*\.go$
pass_filenames: false

AGENTS.md (new file, 1796 lines): diff suppressed because it is too large.


@@ -251,6 +251,37 @@ fi
# PostgreSQL 17 binary paths
PG_BIN="/usr/lib/postgresql/17/bin"
# Generate runtime configuration for frontend
echo "Generating runtime configuration..."
# Detect if email is configured (both SMTP_HOST and DATABASUS_URL must be set)
if [ -n "\${SMTP_HOST:-}" ] && [ -n "\${DATABASUS_URL:-}" ]; then
IS_EMAIL_CONFIGURED="true"
else
IS_EMAIL_CONFIGURED="false"
fi
cat > /app/ui/build/runtime-config.js <<JSEOF
// Runtime configuration injected at container startup
// This file is generated dynamically and should not be edited manually
window.__RUNTIME_CONFIG__ = {
IS_CLOUD: '\${IS_CLOUD:-false}',
GITHUB_CLIENT_ID: '\${GITHUB_CLIENT_ID:-}',
GOOGLE_CLIENT_ID: '\${GOOGLE_CLIENT_ID:-}',
IS_EMAIL_CONFIGURED: '\$IS_EMAIL_CONFIGURED',
CLOUDFLARE_TURNSTILE_SITE_KEY: '\${CLOUDFLARE_TURNSTILE_SITE_KEY:-}'
};
JSEOF
# Inject analytics script if provided (only if not already injected)
if [ -n "\${ANALYTICS_SCRIPT:-}" ]; then
if ! grep -q "rybbit.databasus.com" /app/ui/build/index.html 2>/dev/null; then
echo "Injecting analytics script..."
sed -i "s#</head># \${ANALYTICS_SCRIPT}\\
</head>#" /app/ui/build/index.html
fi
fi
# Ensure proper ownership of data directory
echo "Setting up data directory permissions..."
mkdir -p /databasus-data/pgdata
@@ -372,9 +403,37 @@ SQL
# Start the main application
echo "Starting Databasus application..."
# Check and warn about external database/Valkey usage
if [ -n "\${DANGEROUS_EXTERNAL_DATABASE_DSN:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external database"
echo "=========================================="
echo "DANGEROUS_EXTERNAL_DATABASE_DSN is set."
echo "Application will connect to external PostgreSQL instead of internal instance."
echo "Internal PostgreSQL is still running in the background."
echo "=========================================="
echo ""
fi
if [ -n "\${DANGEROUS_VALKEY_HOST:-}" ]; then
echo ""
echo "=========================================="
echo "WARNING: Using external Valkey"
echo "=========================================="
echo "DANGEROUS_VALKEY_HOST is set."
echo "Application will connect to external Valkey instead of internal instance."
echo "Internal Valkey is still running in the background."
echo "=========================================="
echo ""
fi
exec ./main
EOF
LABEL org.opencontainers.image.source="https://github.com/databasus/databasus"
RUN chmod +x /app/start.sh
EXPOSE 4005
@@ -383,4 +442,4 @@ EXPOSE 4005
VOLUME ["/databasus-data"]
ENTRYPOINT ["/app/start.sh"]
CMD []
CMD []


@@ -11,7 +11,7 @@
[![MongoDB](https://img.shields.io/badge/MongoDB-47A248?logo=mongodb&logoColor=white)](https://www.mongodb.com/)
<br />
[![Apache 2.0 License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE)
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Docker Pulls](https://img.shields.io/docker/pulls/databasus/databasus?color=brightgreen)](https://hub.docker.com/r/databasus/databasus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/databasus/databasus)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/databasus/databasus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/databasus/databasus)
@@ -31,8 +31,6 @@
<img src="assets/dashboard-dark.svg" alt="Databasus Dark Dashboard" width="800" style="margin-bottom: 10px;"/>
<img src="assets/dashboard.svg" alt="Databasus Dashboard" width="800"/>
</div>
---
@@ -52,6 +50,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗑️ **Retention policies**
- **Time period**: Keep backups for a fixed duration (e.g., 7 days, 3 months, 1 year)
- **Count**: Keep a fixed number of the most recent backups (e.g., last 30)
- **GFS (Grandfather-Father-Son)**: Layered retention: keep hourly, daily, weekly, monthly and yearly backups independently for fine-grained long-term history (an enterprise requirement); a sketch of the idea follows this list
- **Size limits**: Set per-backup and total storage size caps to control storage usage
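A minimal sketch of the GFS idea described above, assuming backups are sorted newest first and the newest backup in each hourly/daily/weekly/monthly/yearly bucket wins (illustrative only, not Databasus's actual retention code):

```go
package retention

import (
	"fmt"
	"time"
)

// GFSPolicy keeps the newest backup in each hourly, daily, weekly,
// monthly and yearly bucket, up to the configured count per tier.
type GFSPolicy struct {
	Hourly, Daily, Weekly, Monthly, Yearly int
}

// Keep returns the set of backups to retain; backups must be sorted
// newest first. Anything not in the set may be deleted.
func (p GFSPolicy) Keep(backups []time.Time) map[time.Time]bool {
	keep := map[time.Time]bool{}
	mark := func(bucketOf func(time.Time) string, limit int) {
		seen := map[string]bool{}
		for _, b := range backups {
			if len(seen) >= limit {
				break
			}
			bucket := bucketOf(b)
			if !seen[bucket] { // newest backup wins each bucket
				seen[bucket] = true
				keep[b] = true
			}
		}
	}
	mark(func(t time.Time) string { return t.Format("2006-01-02T15") }, p.Hourly)
	mark(func(t time.Time) string { return t.Format("2006-01-02") }, p.Daily)
	mark(func(t time.Time) string {
		year, week := t.ISOWeek()
		return fmt.Sprintf("%d-W%02d", year, week)
	}, p.Weekly)
	mark(func(t time.Time) string { return t.Format("2006-01") }, p.Monthly)
	mark(func(t time.Time) string { return t.Format("2006") }, p.Yearly)
	return keep
}
```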
### 🗄️ **Multiple storage destinations** <a href="https://databasus.com/storages">(view supported)</a>
- **Local storage**: Keep backups on your VPS/server
@@ -71,6 +76,8 @@
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
- **Read-only user**: Databasus uses a read-only user by default for backups and never stores anything that can modify your data
It is also important for Databasus that you are able to decrypt and restore backups from storages (local, S3, etc.) without Databasus itself. To do so, read our guide on [how to recover directly from storage](https://databasus.com/how-to-recover-without-databasus). We avoid "vendor lock-in" even for an open source tool!
### 👥 **Suitable for teams** <a href="https://databasus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
@@ -220,8 +227,9 @@ For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart
3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
4. **Set database connection**: Enter your database credentials and connection details
5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
7. **Save and start**: Databasus will validate settings and begin the backup schedule
6. **Configure retention policy**: Choose time period, count or GFS to control how long backups are kept
7. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
8. **Save and start**: Databasus will validate settings and begin the backup schedule
### 🔑 Resetting password <a href="https://databasus.com/password">(docs)</a>
@@ -233,56 +241,22 @@ docker exec -it databasus ./main --new-password="YourNewSecurePassword123" --ema
Replace `admin` with the actual email address of the user whose password you want to reset.
### 💾 Backing up Databasus itself
After installation, it is also recommended to <a href="https://databasus.com/faq/#backup-databasus">back up Databasus itself</a> or, at least, to copy the secret key used for encryption (this takes about 30 seconds), so you are able to restore your encrypted backups if you lose access to the server running Databasus or it becomes corrupted.
---
## 📝 License
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
---
## 🤝 Contributing
Contributions are welcome! Read the <a href="https://databasus.com/contribute">contributing guide</a> for more details, priorities and rules. If you want to contribute but don't know where to start, message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
You can also join our large community of developers, DBAs and DevOps engineers on Telegram [@databasus_community](https://t.me/databasus_community).
---
## 📖 Migration guide
Databasus is the new name for Postgresus. You can stay on the latest version of Postgresus if you wish. If you want to migrate, follow the installation steps for Databasus itself.
Just renaming the image is not enough, as Postgresus and Databasus use different data folders and internal database naming.
You can run a new Databasus container with an updated volume alongside the old Postgresus and start it (stop Postgresus first):
```yaml
services:
databasus:
container_name: databasus
image: databasus/databasus:latest
ports:
- "4005:4005"
volumes:
- ./databasus-data:/databasus-data
restart: unless-stopped
```
Then manually move databases from Postgresus to Databasus.
### Why was Postgresus renamed to Databasus?
Databasus has been developed since 2023. It started as an internal tool for backing up production and home-project databases. In early 2025 it was released as an open source project on GitHub. By the end of 2025 it had become popular, and the renaming followed in December 2025.
It was an important step for the project's growth, for a couple of reasons:
1. Postgresus is no longer a little tool that just adds a UI for pg_dump to small projects. It became a tool for individual users, DevOps engineers, DBAs, teams, companies and even large enterprises. Tens of thousands of users use Postgresus every day. Postgresus grew into a reliable backup management tool, and the initial positioning is no longer suitable: the project is not just a UI wrapper, it's a solid backup management system now (though it's still easy to use).
2. New databases are supported: although the primary focus is, and always will be, PostgreSQL (with 100% support in the most efficient way), Databasus has added support for MySQL, MariaDB and MongoDB. More databases will be supported later.
3. Trademark issue: "postgres" is a trademark of PostgreSQL Inc. and cannot be used in the project name. So for safety and legal reasons, we had to rename the project.
## AI disclaimer
There have been questions about AI usage in project development in issues and discussions. As the project focuses on security, reliability and production usage, it's important to explain how AI is used in the development process.


@@ -1,152 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always place private methods to the bottom of file
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
## Structure Order:
1. Type definitions and constants
2. Public methods/functions (uppercase)
3. Private methods/functions (lowercase)
## Examples:
### Service with methods:
```go
type UserService struct {
repository *UserRepository
}
// Public methods first
func (s *UserService) CreateUser(user *User) error {
if err := s.validateUser(user); err != nil {
return err
}
return s.repository.Save(user)
}
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
return s.repository.FindByID(id)
}
// Private methods at the bottom
func (s *UserService) validateUser(user *User) error {
if user.Name == "" {
return errors.New("name is required")
}
return nil
}
```
### Package-level functions:
```go
package utils
// Public functions first
func ProcessData(data []byte) (Result, error) {
cleaned := sanitizeInput(data)
return parseData(cleaned)
}
func ValidateInput(input string) bool {
return isValidFormat(input) && checkLength(input)
}
// Private functions at the bottom
func sanitizeInput(data []byte) []byte {
// implementation
}
func parseData(data []byte) (Result, error) {
// implementation
}
func isValidFormat(input string) bool {
// implementation
}
func checkLength(input string) bool {
// implementation
}
```
### Test files:
```go
package user_test
// Public test functions first
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
user := createTestUser()
result, err := service.CreateUser(user)
assert.NoError(t, err)
assert.NotNil(t, result)
}
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
user := createTestUser()
// test implementation
}
// Private helper functions at the bottom
func createTestUser() *User {
return &User{
Name: "Test User",
Email: "test@example.com",
}
}
func setupTestDatabase() *Database {
// setup implementation
}
```
### Controller example:
```go
type ProjectController struct {
service *ProjectService
}
// Public HTTP handlers first
func (c *ProjectController) CreateProject(ctx *gin.Context) {
var request CreateProjectRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
c.handleError(ctx, err)
return
}
// handler logic
}
func (c *ProjectController) GetProject(ctx *gin.Context) {
projectID := c.extractProjectID(ctx)
// handler logic
}
// Private helper methods at the bottom
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
}
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
return uuid.MustParse(ctx.Param("projectId"))
}
```
## Key Points:
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
- This improves code readability by showing the public API first
- Private helpers are implementation details, so they go at the bottom
- Apply this rule consistently across ALL Go files in the project


@@ -1,45 +0,0 @@
---
description:
globs:
alwaysApply: true
---
## Comment Guidelines
1. **No obvious comments** - Don't state what the code already clearly shows
2. **Functions and variables should have meaningful names** - Code should be self-documenting
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
## Key Principles:
- **Code should tell a story** - Use descriptive variable and function names
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
Example of useless comment:
1.
```sql
// Create projects table
CREATE TABLE projects (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
name TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
```
2.
```go
// Create test project
project := CreateTestProject(projectName, user, router)
```
3.
```go
// CreateValidLogItems creates valid log items for testing
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
```


@@ -1,133 +0,0 @@
---
description:
globs:
alwaysApply: true
---
1. When we write controller:
- we combine all routes to single controller
- names them as .WhatWeDo (not "handlers") concept
2. We use gin and \*gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
3. We document all routes with Swagger in the following format:
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService \*AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}


@@ -1,671 +0,0 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
```
package audit_logs
import (
"net/http"
user_models "databasus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
```
------ backend/internal/features/audit_logs/controller_test.go ------
```
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
"databasus-backend/internal/storage"
test_utils "databasus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
createAuditLog(service, "Test log with project", nil, &projectID)
createAuditLog(service, "Test log standalone", nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
assert.GreaterOrEqual(t, response.Total, int64(3))
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, "Test log with user")
assert.Contains(t, messages, "Test log with project")
assert.Contains(t, messages, "Test log standalone")
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
projectID := uuid.New()
// Create test logs for different users
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
createAuditLog(service, "Test project log", nil, &projectID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
assert.Equal(t, 2, len(user1Response.AuditLogs))
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, "Test log user1 first")
assert.Contains(t, messages, "Test log user1 second")
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
service := GetAuditLogService()
db := storage.GetDb()
baseTime := time.Now().UTC()
// Create logs with different timestamps
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
// Test filtering - get logs before 1 hour ago
beforeTime := baseTime.Add(-1 * time.Hour)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
// Verify only old log is returned
messages := extractMessages(filteredResponse.AuditLogs)
assert.Contains(t, messages, "Test old log")
assert.NotContains(t, messages, "Test recent log")
assert.NotContains(t, messages, "Test current log")
// Test without filter - should get all logs
var allResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}
```
------ backend/internal/features/audit_logs/di.go ------
```
package audit_logs
import (
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
```
------ backend/internal/features/audit_logs/dto.go ------
```
package audit_logs
import "time"
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLog `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
```
------ backend/internal/features/audit_logs/models.go ------
```
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}
```
------ backend/internal/features/audit_logs/repository.go ------
```
package audit_logs
import (
"databasus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("user_id = ?", userID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByProject(
projectID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLog, error) {
var auditLogs []*AuditLog
query := storage.GetDb().
Where("project_id = ?", projectID).
Order("created_at DESC")
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.
Limit(limit).
Offset(offset).
Find(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}
```
------ backend/internal/features/audit_logs/service.go ------
```
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "databasus-backend/internal/features/users/enums"
user_models "databasus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
projectID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
ProjectID: projectID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetProjectAuditLogs(
projectID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
```
------ backend/internal/features/audit_logs/service_test.go ------
```
package audit_logs
import (
"testing"
"time"
user_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
)
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
project1ID, project2ID := uuid.New(), uuid.New()
// Create test logs for projects
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
createAuditLog(service, "Test no project log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test project 1 logs
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project1Response.AuditLogs))
messages := extractMessages(project1Response.AuditLogs)
assert.Contains(t, messages, "Test project1 log first")
assert.Contains(t, messages, "Test project1 log second")
for _, log := range project1Response.AuditLogs {
assert.Equal(t, &project1ID, log.ProjectID)
}
// Test project 2 logs
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(project2Response.AuditLogs))
messages2 := extractMessages(project2Response.AuditLogs)
assert.Contains(t, messages2, "Test project2 log first")
assert.Contains(t, messages2, "Test project2 log second")
// Test pagination
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
}
}
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
service.WriteAuditLog(message, userID, projectID)
}
func extractMessages(logs []*AuditLog) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
log := &AuditLog{
ID: uuid.New(),
UserID: userID,
Message: message,
CreatedAt: createdAt,
}
db.Create(log)
}
```


@@ -1,74 +0,0 @@
---
description:
globs:
alwaysApply: true
---
For DI files, use the implicit field declaration style (especially
for controllers, services, repositories, use cases, etc., not for simple
data structures).
So, instead of:
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
Use:
var orderController = &OrderController{
orderService,
bot_users.GetBotUserService(),
bots.GetBotService(),
users.GetUserService(),
}
This is needed so we cannot forget to update the DI wiring
when we add a new dependency: with positional fields, the compiler
rejects the literal until every field is supplied.
---
Please enforce this style when a file looks like the following (note the
service/controller/repository definitions and getters):
var orderBackgroundService = &OrderBackgroundService{
orderService: orderService,
orderPaymentRepository: orderPaymentRepository,
botService: bots.GetBotService(),
paymentSettingsService: payment_settings.GetPaymentSettingsService(),
orderSubscriptionListeners: []OrderSubscriptionListener{},
}
var orderController = &OrderController{
orderService: orderService,
botUserService: bot_users.GetBotUserService(),
botService: bots.GetBotService(),
userService: users.GetUserService(),
}
func GetUniquePaymentRepository() *repositories.UniquePaymentRepository {
return uniquePaymentRepository
}
func GetOrderPaymentRepository() *repositories.OrderPaymentRepository {
return orderPaymentRepository
}
func GetOrderService() *OrderService {
return orderService
}
func GetOrderController() *OrderController {
return orderController
}
func GetOrderBackgroundService() *OrderBackgroundService {
return orderBackgroundService
}
func GetOrderRepository() *repositories.OrderRepository {
return orderRepository
}


@@ -1,27 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When writing migrations:
- write them for PostgreSQL
- for UUID PRIMARY KEYs use gen_random_uuid()
- for time use TIMESTAMPTZ (timestamp with time zone)
- split table, constraint, and index declarations (table first, then the others one by one)
- format SQL in a pretty way (add spaces, align column types), with constraints split across lines. Example:
CREATE TABLE marketplace_info (
bot_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
title TEXT NOT NULL,
description TEXT NOT NULL,
short_description TEXT NOT NULL,
tutorial_url TEXT,
info_order BIGINT NOT NULL DEFAULT 0,
is_published BOOLEAN NOT NULL DEFAULT FALSE
);
ALTER TABLE marketplace_info_images
ADD CONSTRAINT fk_marketplace_info_images_bot_id
FOREIGN KEY (bot_id)
REFERENCES marketplace_info (bot_id);
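Since indexes should also be declared separately, a hedged continuation of the example might add (index name assumed for illustration):
CREATE INDEX idx_marketplace_info_images_bot_id
    ON marketplace_info_images (bot_id);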


@@ -1,12 +0,0 @@
---
description:
globs:
alwaysApply: true
---
When applying changes, do not forget to refactor old code.
You can shorten it, make it more readable, improve code quality, etc.
Common logic can be extracted to functions, constants, files, etc.
After each large change (more than ~50-100 lines of code), always run `make lint` (from the backend root folder) and, if you changed the frontend, run `npm run format` (from the frontend root folder).


@@ -1,147 +0,0 @@
---
description:
globs:
alwaysApply: true
---
After writing tests, always run them and verify that they pass.
## Test Naming Format
Use these naming patterns:
- `Test_WhatWeDo_WhatWeExpect`
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
## Examples from Real Codebase:
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
## Testing Philosophy
**Prefer Controllers Over Unit Tests:**
- Test through HTTP endpoints via controllers whenever possible
- Avoid testing repositories, services in isolation - test via API instead
- Only use unit tests for complex model logic when no API exists
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
**Extract Common Logic to Testing Utilities:**
- Create `testing.go` or `testing/testing.go` files for shared test utilities
- Extract helpers for router creation, user setup, and model creation (via the API, not just struct construction)
- Reuse common patterns across different test files
**Refactor Existing Tests:**
- When working with existing tests, always look for opportunities to refactor and improve
- Extract repetitive setup code to common utilities
- Simplify complex tests by breaking them into smaller, focused tests
- Replace inline test data creation with reusable helper functions
- Consolidate similar test patterns across different test files
- Make tests more readable and maintainable for other developers
## Testing Utilities Structure
**Create `testing.go` or `testing/testing.go` files with common utilities:**
```go
package projects_testing
// CreateTestRouter creates unified router for all controllers
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
for _, controller := range controllers {
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
controller.RegisterRoutes(routerGroup)
}
}
return router
}
// CreateTestProjectViaAPI creates project through HTTP API
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
request := projects_dto.CreateProjectRequestDTO{Name: name}
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
// Handle response...
return project, owner.Token
}
// AddMemberToProject adds member via API call
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
// Implementation...
}
```
## Controller Test Examples
**Permission-based testing:**
```go
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
var response ApiKey
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
assert.Equal(t, "Test API Key", response.Name)
assert.NotEmpty(t, response.Token)
}
```
**Cross-project security testing:**
```go
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
// Try to update via different project endpoint
newName := "Hacked Key"
request := UpdateApiKeyRequestDTO{Name: &newName}
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
}
```
**E2E lifecycle testing:**
```go
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
// 1. Create project
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
project := projects_testing.CreateTestProject("E2E Project", owner, router)
// 2. Add member
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
// 3. Promote to admin
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
// 4. Transfer ownership
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
// 5. Verify new owner can manage project
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
assert.Equal(t, project.ID, finalProject.ID)
}
```


@@ -1,6 +0,0 @@
---
description:
globs:
alwaysApply: true
---
Always use time.Now().UTC() instead of time.Now()
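A minimal runnable sketch of the preferred call (illustrative only):
package main

import (
	"fmt"
	"time"
)

func main() {
	// Preferred: normalize to UTC so stored timestamps compare
	// consistently across nodes and database round-trips.
	createdAt := time.Now().UTC()

	fmt.Println(createdAt.Format(time.RFC3339))
}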


@@ -2,8 +2,18 @@
DEV_DB_NAME=databasus
DEV_DB_USERNAME=postgres
DEV_DB_PASSWORD=Q1234567
#app
# app
ENV_MODE=development
# logging
SHOW_DB_INSTALLATION_VERIFICATION_LOGS=true
VICTORIA_LOGS_URL=http://localhost:9428
VICTORIA_LOGS_PASSWORD=devpassword
# tests
TEST_LOCALHOST=localhost
IS_SKIP_EXTERNAL_RESOURCES_TESTS=false
# cloudflare turnstile
CLOUDFLARE_TURNSTILE_SITE_KEY=
CLOUDFLARE_TURNSTILE_SECRET_KEY=
# db
DATABASE_DSN=host=dev-db user=postgres password=Q1234567 dbname=databasus port=5437 sslmode=disable
DATABASE_URL=postgres://postgres:Q1234567@dev-db:5437/databasus?sslmode=disable

backend/.gitignore

@@ -18,4 +18,5 @@ pgdata-for-restore/
temp/
cmd.exe
temp/
valkey-data/
victoria-logs-data/


@@ -16,7 +16,6 @@ import (
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_download "databasus-backend/internal/features/backups/backups/download"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -26,8 +25,10 @@ import (
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/restores"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
system_healthcheck "databasus-backend/internal/features/system/healthcheck"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_controllers "databasus-backend/internal/features/users/controllers"
users_middleware "databasus-backend/internal/features/users/middleware"
users_services "databasus-backend/internal/features/users/services"
@@ -59,6 +60,8 @@ func main() {
cache_utils.TestCacheConnection()
if config.GetEnv().IsPrimaryNode {
log.Info("Clearing cache...")
err := cache_utils.ClearAllCache()
if err != nil {
log.Error("Failed to clear cache", "error", err)
@@ -182,6 +185,9 @@ func startServerWithGracefulShutdown(log *slog.Logger, app *gin.Engine) {
<-quit
log.Info("Shutdown signal received")
// Gracefully shutdown VictoriaLogs writer
logger.ShutdownVictoriaLogs(5 * time.Second)
// The context is used to inform the server it has 10 seconds to finish
// the request it is currently handling
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
@@ -239,7 +245,7 @@ func setUpDependencies() {
notifiers.SetupDependencies()
storages.SetupDependencies()
backups_config.SetupDependencies()
backups_cancellation.SetupDependencies()
task_cancellation.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {
@@ -257,20 +263,24 @@ func runBackgroundTasks(log *slog.Logger) {
cancel()
}()
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
if config.GetEnv().IsPrimaryNode {
log.Info("Starting primary node background tasks...")
err := files_utils.CleanFolder(config.GetEnv().TempFolder)
if err != nil {
log.Error("Failed to clean temp folder", "error", err)
}
go runWithPanicLogging(log, "backup background service", func() {
backuping.GetBackupsScheduler().Run(ctx)
})
go runWithPanicLogging(log, "backup cleaner background service", func() {
backuping.GetBackupCleaner().Run(ctx)
})
go runWithPanicLogging(log, "restore background service", func() {
restores.GetRestoreBackgroundService().Run(ctx)
restoring.GetRestoresScheduler().Run(ctx)
})
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
@@ -284,18 +294,30 @@ func runBackgroundTasks(log *slog.Logger) {
go runWithPanicLogging(log, "download token cleanup background service", func() {
backups_download.GetDownloadTokenBackgroundService().Run(ctx)
})
go runWithPanicLogging(log, "backup nodes registry background service", func() {
backuping.GetBackupNodesRegistry().Run(ctx)
})
go runWithPanicLogging(log, "restore nodes registry background service", func() {
restoring.GetRestoreNodesRegistry().Run(ctx)
})
} else {
log.Info("Skipping primary node tasks as not primary node")
}
if config.GetEnv().IsBackupNode {
if config.GetEnv().IsProcessingNode {
log.Info("Starting backup node background tasks...")
go runWithPanicLogging(log, "backup node", func() {
backuping.GetBackuperNode().Run(ctx)
})
go runWithPanicLogging(log, "restore node", func() {
restoring.GetRestorerNode().Run(ctx)
})
} else {
log.Info("Skipping backup node tasks as not backup node")
log.Info("Skipping backup/restore node tasks as not backup node")
}
}


@@ -34,6 +34,20 @@ services:
retries: 5
start_period: 20s
# VictoriaLogs for external logging
victoria-logs:
image: victoriametrics/victoria-logs:latest
container_name: victoria-logs
ports:
- "9428:9428"
command:
- -storageDataPath=/victoria-logs-data
- -retentionPeriod=7d
- -httpAuth.password=devpassword
volumes:
- ./victoria-logs-data:/victoria-logs-data
restart: unless-stopped
# Test MinIO container
test-minio:
image: minio/minio:latest


@@ -28,7 +28,6 @@ require (
github.com/valkey-io/valkey-go v1.0.70
go.mongodb.org/mongo-driver v1.17.6
golang.org/x/crypto v0.46.0
golang.org/x/time v0.14.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
@@ -186,6 +185,7 @@ require (
go.yaml.in/yaml/v2 v2.4.3 // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/term v0.38.0 // indirect
golang.org/x/time v0.14.0 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
gopkg.in/validator.v2 v2.0.1 // indirect
moul.io/http2curl/v2 v2.3.0 // indirect


@@ -9,7 +9,6 @@ import (
"strings"
"sync"
"github.com/google/uuid"
"github.com/ilyakaznacheev/cleanenv"
"github.com/joho/godotenv"
)
@@ -23,17 +22,30 @@ const (
type EnvVariables struct {
IsTesting bool
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
MongodbInstallDir string `env:"MONGODB_INSTALL_DIR"`
NodeID string
// Internal database
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
// Internal Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME" required:"true"`
ValkeyPassword string `env:"VALKEY_PASSWORD" required:"true"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
IsCloud bool `env:"IS_CLOUD"`
TestLocalhost string `env:"TEST_LOCALHOST"`
ShowDbInstallationVerificationLogs bool `env:"SHOW_DB_INSTALLATION_VERIFICATION_LOGS"`
IsSkipExternalResourcesTests bool `env:"IS_SKIP_EXTERNAL_RESOURCES_TESTS"`
IsManyNodesMode bool `env:"IS_MANY_NODES_MODE"`
IsPrimaryNode bool `env:"IS_PRIMARY_NODE"`
IsBackupNode bool `env:"IS_BACKUP_NODE"`
IsProcessingNode bool `env:"IS_PROCESSING_NODE"`
NodeNetworkThroughputMBs int `env:"NODE_NETWORK_THROUGHPUT_MBPS"`
DataFolder string
@@ -86,19 +98,16 @@ type EnvVariables struct {
TestMongodb70Port string `env:"TEST_MONGODB_70_PORT"`
TestMongodb82Port string `env:"TEST_MONGODB_82_PORT"`
// Valkey
ValkeyHost string `env:"VALKEY_HOST" required:"true"`
ValkeyPort string `env:"VALKEY_PORT" required:"true"`
ValkeyUsername string `env:"VALKEY_USERNAME"`
ValkeyPassword string `env:"VALKEY_PASSWORD"`
ValkeyIsSsl bool `env:"VALKEY_IS_SSL" required:"true"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// Cloudflare Turnstile
CloudflareTurnstileSecretKey string `env:"CLOUDFLARE_TURNSTILE_SECRET_KEY"`
CloudflareTurnstileSiteKey string `env:"CLOUDFLARE_TURNSTILE_SITE_KEY"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
@@ -109,6 +118,15 @@ type EnvVariables struct {
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
// SMTP configuration (optional)
SMTPHost string `env:"SMTP_HOST"`
SMTPPort int `env:"SMTP_PORT"`
SMTPUser string `env:"SMTP_USER"`
SMTPPassword string `env:"SMTP_PASSWORD"`
// Application URL (optional) - used for email links
DatabasusURL string `env:"DATABASUS_URL"`
}
var (
@@ -169,6 +187,21 @@ func loadEnvVariables() {
os.Exit(1)
}
// Set default value for ShowDbInstallationVerificationLogs if not defined
if os.Getenv("SHOW_DB_INSTALLATION_VERIFICATION_LOGS") == "" {
env.ShowDbInstallationVerificationLogs = true
}
// Set default value for IsSkipExternalResourcesTests if not defined
if os.Getenv("IS_SKIP_EXTERNAL_RESOURCES_TESTS") == "" {
env.IsSkipExternalResourcesTests = false
}
// Set default value for IsCloud if not defined
if os.Getenv("IS_CLOUD") == "" {
env.IsCloud = false
}
for _, arg := range os.Args {
if strings.Contains(arg, "test") {
env.IsTesting = true
@@ -176,6 +209,14 @@ func loadEnvVariables() {
}
}
// Check for external database override
if externalDsn := os.Getenv("DANGEROUS_EXTERNAL_DATABASE_DSN"); externalDsn != "" {
log.Warn(
"Using DANGEROUS_EXTERNAL_DATABASE_DSN - connecting to external database instead of internal PostgreSQL",
)
env.DatabaseDsn = externalDsn
}
if env.DatabaseDsn == "" {
log.Error("DATABASE_DSN is empty")
os.Exit(1)
@@ -192,25 +233,48 @@ func loadEnvVariables() {
log.Info("ENV_MODE loaded", "mode", env.EnvMode)
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
tools.VerifyPostgresesInstallation(
log,
env.EnvMode,
env.PostgresesInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
tools.VerifyMysqlInstallation(
log,
env.EnvMode,
env.MysqlInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
tools.VerifyMariadbInstallation(
log,
env.EnvMode,
env.MariadbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.MongodbInstallDir = filepath.Join(backendRoot, "tools", "mongodb")
tools.VerifyMongodbInstallation(log, env.EnvMode, env.MongodbInstallDir)
tools.VerifyMongodbInstallation(
log,
env.EnvMode,
env.MongodbInstallDir,
env.ShowDbInstallationVerificationLogs,
)
env.NodeID = uuid.New().String()
if env.NodeNetworkThroughputMBs == 0 {
env.NodeNetworkThroughputMBs = 125 // 1 Gbit/s
}
if !env.IsManyNodesMode {
env.IsPrimaryNode = true
env.IsBackupNode = true
env.IsProcessingNode = true
}
if env.TestLocalhost == "" {
env.TestLocalhost = "localhost"
}
// Valkey
@@ -223,6 +287,27 @@ func loadEnvVariables() {
os.Exit(1)
}
// Check for external Valkey override
if externalValkeyHost := os.Getenv("DANGEROUS_VALKEY_HOST"); externalValkeyHost != "" {
log.Warn(
"Using DANGEROUS_VALKEY_* variables - connecting to external Valkey instead of internal instance",
)
env.ValkeyHost = externalValkeyHost
if externalValkeyPort := os.Getenv("DANGEROUS_VALKEY_PORT"); externalValkeyPort != "" {
env.ValkeyPort = externalValkeyPort
}
if externalValkeyUsername := os.Getenv("DANGEROUS_VALKEY_USERNAME"); externalValkeyUsername != "" {
env.ValkeyUsername = externalValkeyUsername
}
if externalValkeyPassword := os.Getenv("DANGEROUS_VALKEY_PASSWORD"); externalValkeyPassword != "" {
env.ValkeyPassword = externalValkeyPassword
}
if externalValkeyIsSsl := os.Getenv("DANGEROUS_VALKEY_IS_SSL"); externalValkeyIsSsl != "" {
env.ValkeyIsSsl = externalValkeyIsSsl == "true"
}
}
// Store the data and temp folders one level below the root
// (projectRoot/databasus-data -> /databasus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "databasus-data", "backups")
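For illustration, the DANGEROUS_* overrides read above might be supplied like this (hostnames and credentials hypothetical; only the variable names are taken from the code):
DANGEROUS_EXTERNAL_DATABASE_DSN=host=external-db user=postgres password=secret dbname=databasus port=5432 sslmode=disable
DANGEROUS_VALKEY_HOST=external-valkey
DANGEROUS_VALKEY_PORT=6379
DANGEROUS_VALKEY_IS_SSL=true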


@@ -2,34 +2,50 @@ package audit_logs
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type AuditLogBackgroundService struct {
auditLogService *AuditLogService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *AuditLogBackgroundService) Run(ctx context.Context) {
s.logger.Info("Starting audit log cleanup background service")
wasAlreadyRun := s.hasRun.Load()
if ctx.Err() != nil {
return
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
s.logger.Info("Starting audit log cleanup background service")
for {
select {
case <-ctx.Done():
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
ticker := time.NewTicker(1 * time.Hour)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.cleanOldAuditLogs(); err != nil {
s.logger.Error("Failed to clean old audit logs", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
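The hunk above interleaves the removed and added bodies; as far as it can be read here, the new shape is a sync.Once guard plus an atomic flag so that a second Run call panics instead of silently returning. A minimal self-contained sketch of that pattern (toy worker type, hypothetical):
package main

import (
	"context"
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

type worker struct {
	runOnce sync.Once
	hasRun  atomic.Bool
}

func (w *worker) Run(ctx context.Context) {
	wasAlreadyRun := w.hasRun.Load()
	w.runOnce.Do(func() {
		w.hasRun.Store(true)
		ticker := time.NewTicker(time.Hour)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// periodic work would go here
			}
		}
	})
	// A second caller finds hasRun already set: runOnce.Do is a
	// no-op for it, so it fails loudly instead of returning silently.
	if wasAlreadyRun {
		panic(fmt.Sprintf("%T.Run() called multiple times", w))
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()
	(&worker{}).Run(ctx)
}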


@@ -1,6 +1,9 @@
package audit_logs
import (
"sync"
"sync/atomic"
users_services "databasus-backend/internal/features/users/services"
"databasus-backend/internal/util/logger"
)
@@ -14,8 +17,10 @@ var auditLogController = &AuditLogController{
auditLogService,
}
var auditLogBackgroundService = &AuditLogBackgroundService{
auditLogService,
logger.GetLogger(),
auditLogService: auditLogService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
func GetAuditLogService() *AuditLogService {
@@ -30,8 +35,23 @@ func GetAuditLogBackgroundService() *AuditLogBackgroundService {
return auditLogBackgroundService
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}


@@ -1,23 +1,33 @@
package backuping
import (
"bytes"
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
"encoding/json"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
)
const (
heartbeatTickerInterval = 15 * time.Second
backuperHealthcheckThreshold = 5 * time.Minute
)
type BackuperNode struct {
@@ -28,77 +38,93 @@ type BackuperNode struct {
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
notificationSender backups_core.NotificationSender
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
backupCancelManager *tasks_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
logger *slog.Logger
createBackupUseCase backups_core.CreateBackupUsecase
nodeID uuid.UUID
lastHeartbeat time.Time
runOnce sync.Once
hasRun atomic.Bool
}
func (n *BackuperNode) Run(ctx context.Context) {
n.lastHeartbeat = time.Now().UTC()
wasAlreadyRun := n.hasRun.Load()
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
n.runOnce.Do(func() {
n.hasRun.Store(true)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
n.lastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
throughputMBs := config.GetEnv().NodeNetworkThroughputMBs
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
n.MakeBackup(backupID, isCallNotifier)
if err := n.nodesRegistry.PublishBackupCompletion(n.nodeID.String(), backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
backupNode := BackupNode{
ID: n.nodeID,
ThroughputMBs: throughputMBs,
LastHeartbeat: time.Now().UTC(),
}
}
if err := n.nodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID.String(), backupHandler); err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.nodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), backupNode); err != nil {
n.logger.Error("Failed to register node in registry", "error", err)
panic(err)
}
}()
ticker := time.NewTicker(15 * time.Second)
defer ticker.Stop()
backupHandler := func(backupID uuid.UUID, isCallNotifier bool) {
go func() {
n.MakeBackup(backupID, isCallNotifier)
if err := n.backupNodesRegistry.PublishBackupCompletion(n.nodeID, backupID); err != nil {
n.logger.Error(
"Failed to publish backup completion",
"error",
err,
"backupID",
backupID,
)
}
}()
}
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.nodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
err := n.backupNodesRegistry.SubscribeNodeForBackupsAssignment(n.nodeID, backupHandler)
if err != nil {
n.logger.Error("Failed to subscribe to backup assignments", "error", err)
panic(err)
}
defer func() {
if err := n.backupNodesRegistry.UnsubscribeNodeForBackupsAssignments(); err != nil {
n.logger.Error("Failed to unsubscribe from backup assignments", "error", err)
}
}()
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
ticker := time.NewTicker(heartbeatTickerInterval)
defer ticker.Stop()
n.logger.Info("Backup node started", "nodeID", n.nodeID, "throughput", throughputMBs)
for {
select {
case <-ctx.Done():
n.logger.Info("Shutdown signal received, unregistering node", "nodeID", n.nodeID)
if err := n.backupNodesRegistry.UnregisterNodeFromRegistry(backupNode); err != nil {
n.logger.Error("Failed to unregister node from registry", "error", err)
}
return
case <-ticker.C:
n.sendHeartbeat(&backupNode)
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", n))
}
}
func (n *BackuperNode) IsBackuperRunning() bool {
return n.lastHeartbeat.After(time.Now().UTC().Add(-5 * time.Minute))
return n.lastHeartbeat.After(time.Now().UTC().Add(-backuperHealthcheckThreshold))
}
func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
@@ -135,32 +161,86 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
start := time.Now().UTC()
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterTask(backup.ID, cancel)
defer n.backupCancelManager.UnregisterTask(backup.ID)
backupProgressListener := func(
completedMBs float64,
) {
backup.BackupSizeMb = completedMBs
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Check size limit (0 = unlimited)
if backupConfig.MaxBackupSizeMB > 0 &&
completedMBs > float64(backupConfig.MaxBackupSizeMB) {
errMsg := fmt.Sprintf(
"backup size (%.2f MB) exceeded maximum allowed size (%d MB)",
completedMBs,
backupConfig.MaxBackupSizeMB,
)
backup.Status = backups_core.BackupStatusFailed
backup.IsSkipRetry = true
backup.FailMessage = &errMsg
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to save backup with size exceeded error", "error", err)
}
cancel() // Cancel the backup context
return
}
if err := n.backupRepository.Save(backup); err != nil {
n.logger.Error("Failed to update backup progress", "error", err)
}
}
ctx, cancel := context.WithCancel(context.Background())
n.backupCancelManager.RegisterBackup(backup.ID, cancel)
defer n.backupCancelManager.UnregisterBackup(backup.ID)
backupMetadata, err := n.createBackupUseCase.Execute(
ctx,
backup.ID,
backup,
backupConfig,
database,
storage,
backupProgressListener,
)
if err != nil {
// Check if backup was already marked as failed by progress listener (e.g., size limit exceeded)
// If so, skip error handling to avoid overwriting the status
currentBackup, fetchErr := n.backupRepository.FindByID(backup.ID)
if fetchErr == nil && currentBackup.Status == backups_core.BackupStatusFailed {
n.logger.Warn(
"Backup already marked as failed by progress listener, skipping error handling",
"backupId",
backup.ID,
"failMessage",
*currentBackup.FailMessage,
)
// Still call notification for size limit failures
n.SendBackupNotification(
backupConfig,
currentBackup,
backups_config.NotificationBackupFailed,
currentBackup.FailMessage,
)
return
}
errMsg := err.Error()
// Log detailed error information for debugging
n.logger.Error("Backup execution failed",
"backupId", backup.ID,
"databaseId", databaseID,
"databaseType", database.Type,
"storageId", storage.ID,
"storageType", storage.Type,
"error", err,
"errorMessage", errMsg,
)
// Check if backup was cancelled (not due to shutdown)
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
strings.Contains(errMsg, "context canceled") ||
@@ -168,6 +248,12 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
isShutdown := strings.Contains(errMsg, "shutdown")
if isCancelled && !isShutdown {
n.logger.Warn("Backup was cancelled by user or system",
"backupId", backup.ID,
"isCancelled", isCancelled,
"isShutdown", isShutdown,
)
backup.Status = backups_core.BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
@@ -179,7 +265,7 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
// Delete partial backup from storage
storage, storageErr := n.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.ID); deleteErr != nil {
if deleteErr := storage.DeleteFile(n.fieldEncryptor, backup.FileName); deleteErr != nil {
n.logger.Error(
"Failed to delete partial backup file",
"backupId",
@@ -227,6 +313,13 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backupMetadata.BackupID = backup.ID
if err := backupMetadata.Validate(); err != nil {
n.logger.Error("Failed to validate backup metadata", "error", err)
return
}
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
@@ -237,6 +330,39 @@ func (n *BackuperNode) MakeBackup(backupID uuid.UUID, isCallNotifier bool) {
return
}
// Save metadata file to storage
if backupMetadata != nil {
metadataJSON, err := json.Marshal(backupMetadata)
if err != nil {
n.logger.Error("Failed to marshal backup metadata to JSON",
"backupId", backup.ID,
"error", err,
)
} else {
metadataReader := bytes.NewReader(metadataJSON)
metadataFileName := backup.FileName + ".metadata"
if err := storage.SaveFile(
context.Background(),
n.fieldEncryptor,
n.logger,
metadataFileName,
metadataReader,
); err != nil {
n.logger.Error("Failed to save backup metadata file to storage",
"backupId", backup.ID,
"fileName", metadataFileName,
"error", err,
)
} else {
n.logger.Info("Backup metadata file saved successfully",
"backupId", backup.ID,
"fileName", metadataFileName,
)
}
}
}
// Update database last backup time
now := time.Now().UTC()
if updateErr := n.databaseService.SetLastBackupTime(databaseID, now); updateErr != nil {
@@ -337,8 +463,7 @@ func (n *BackuperNode) SendBackupNotification(
func (n *BackuperNode) sendHeartbeat(backupNode *BackupNode) {
n.lastHeartbeat = time.Now().UTC()
backupNode.LastHeartbeat = time.Now().UTC()
if err := n.nodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
if err := n.backupNodesRegistry.HearthbeatNodeInRegistry(time.Now().UTC(), *backupNode); err != nil {
n.logger.Error("Failed to send heartbeat", "error", err)
}
}


@@ -1,13 +1,10 @@
package backuping
import (
"context"
"errors"
"strings"
"testing"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -18,7 +15,6 @@ import (
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
@@ -158,35 +154,120 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
})
}
type CreateFailedBackupUsecase struct {
}
func Test_BackupSizeLimits(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
type CreateSuccessBackupUsecase struct{}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
t.Run("UnlimitedSize_MaxBackupSizeMBIsZero_BackupCompletes", func(t *testing.T) {
// Enable backups with unlimited size (0)
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 0 // unlimited
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateLargeBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully even with large size
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(10000), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
t.Run("SizeExceeded_BackupFailedWithIsSkipRetry", func(t *testing.T) {
// Enable backups with 5 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 5
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateProgressiveBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup was marked as failed with IsSkipRetry=true
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusFailed, updatedBackup.Status)
assert.True(t, updatedBackup.IsSkipRetry)
assert.NotNil(t, updatedBackup.FailMessage)
assert.Contains(t, *updatedBackup.FailMessage, "exceeded maximum allowed size")
assert.Contains(t, *updatedBackup.FailMessage, "10.00 MB")
assert.Contains(t, *updatedBackup.FailMessage, "5 MB")
assert.Greater(t, updatedBackup.BackupSizeMb, float64(5))
})
t.Run("SizeWithinLimit_BackupCompletes", func(t *testing.T) {
// Enable backups with 100 MB limit
backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backupConfig.MaxBackupSizeMB = 100
backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
backuperNode := CreateTestBackuperNode()
backuperNode.createBackupUseCase = &CreateMediumBackupUsecase{}
// Create a backup record
backup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(backup)
assert.NoError(t, err)
backuperNode.MakeBackup(backup.ID, false)
// Verify backup completed successfully
updatedBackup, err := backupRepository.FindByID(backup.ID)
assert.NoError(t, err)
assert.Equal(t, backups_core.BackupStatusCompleted, updatedBackup.Status)
assert.Equal(t, float64(50), updatedBackup.BackupSizeMb)
assert.Nil(t, updatedBackup.FailMessage)
})
}


@@ -0,0 +1,461 @@
package backuping
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
util_encryption "databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
)
const (
cleanerTickerInterval = 1 * time.Minute
recentBackupGracePeriod = 60 * time.Minute
)
type BackupCleaner struct {
backupRepository *backups_core.BackupRepository
storageService *storages.StorageService
backupConfigService *backups_config.BackupConfigService
fieldEncryptor util_encryption.FieldEncryptor
logger *slog.Logger
backupRemoveListeners []backups_core.BackupRemoveListener
runOnce sync.Once
hasRun atomic.Bool
}
func (c *BackupCleaner) Run(ctx context.Context) {
wasAlreadyRun := c.hasRun.Load()
c.runOnce.Do(func() {
c.hasRun.Store(true)
if ctx.Err() != nil {
return
}
ticker := time.NewTicker(cleanerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := c.cleanByRetentionPolicy(); err != nil {
c.logger.Error("Failed to clean backups by retention policy", "error", err)
}
if err := c.cleanExceededBackups(); err != nil {
c.logger.Error("Failed to clean exceeded backups", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", c))
}
}
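// Illustrative usage sketch (not part of this diff): Run is designed for a
// single long-lived goroutine per process, and a second call panics by design:
//
//	ctx, cancel := context.WithCancel(context.Background())
//	defer cancel()
//	go GetBackupCleaner().Run(ctx) // calling Run again would panic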
func (c *BackupCleaner) DeleteBackup(backup *backups_core.Backup) error {
for _, listener := range c.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := c.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(c.fieldEncryptor, backup.FileName)
if err != nil {
// We do not return the error here because cleanup is sometimes performed
// before an unavailable storage is removed or changed, so we should
// proceed even on failure. An S3 bucket or other storage may be
// temporarily unreachable; that should not block us.
c.logger.Error("Failed to delete backup file", "error", err)
}
metadataFileName := backup.FileName + ".metadata"
if err := storage.DeleteFile(c.fieldEncryptor, metadataFileName); err != nil {
c.logger.Error("Failed to delete backup metadata file", "error", err)
}
return c.backupRepository.DeleteByID(backup.ID)
}
func (c *BackupCleaner) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
c.backupRemoveListeners = append(c.backupRemoveListeners, listener)
}
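// Illustrative listener sketch, using only the OnBeforeBackupRemove signature
// visible in DeleteBackup above; auditRemoveListener is a hypothetical name:
//
//	type auditRemoveListener struct{ logger *slog.Logger }
//
//	func (l *auditRemoveListener) OnBeforeBackupRemove(backup *backups_core.Backup) error {
//		l.logger.Info("Backup is about to be removed", "backupId", backup.ID)
//		return nil // a non-nil error aborts the deletion
//	}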
func (c *BackupCleaner) cleanByRetentionPolicy() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
var cleanErr error
switch backupConfig.RetentionPolicyType {
case backups_config.RetentionPolicyTypeCount:
cleanErr = c.cleanByCount(backupConfig)
case backups_config.RetentionPolicyTypeGFS:
cleanErr = c.cleanByGFS(backupConfig)
default:
cleanErr = c.cleanByTimePeriod(backupConfig)
}
if cleanErr != nil {
c.logger.Error(
"Failed to clean backups by retention policy",
"databaseId", backupConfig.DatabaseID,
"policy", backupConfig.RetentionPolicyType,
"error", cleanErr,
)
}
}
return nil
}
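// For reference, a config that takes the GFS branch above could be built with
// the test helper used earlier in this diff (field values are illustrative):
//
//	backupConfig := backups_config.EnableBackupsForTestDatabase(database.ID, storage)
//	backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeGFS
//	backupConfig.RetentionGfsDays = 7  // newest backup of each of the last 7 days
//	backupConfig.RetentionGfsWeeks = 4 // newest backup of each of the last 4 weeks
//	backupConfig, err := backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)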
func (c *BackupCleaner) cleanExceededBackups() error {
enabledBackupConfigs, err := c.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
if backupConfig.MaxBackupsTotalSizeMB <= 0 {
continue
}
if err := c.cleanExceededBackupsForDatabase(
backupConfig.DatabaseID,
backupConfig.MaxBackupsTotalSizeMB,
); err != nil {
c.logger.Error(
"Failed to clean exceeded backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
}
return nil
}
func (c *BackupCleaner) cleanByTimePeriod(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionTimePeriod == "" {
return nil
}
if backupConfig.RetentionTimePeriod == period.PeriodForever {
return nil
}
storeDuration := backupConfig.RetentionTimePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := c.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
return fmt.Errorf(
"failed to find old backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
for _, backup := range oldBackups {
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
c.logger.Info(
"Deleted old backup",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
}
return nil
}
func (c *BackupCleaner) cleanByCount(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionCount <= 0 {
return nil
}
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
backupConfig.DatabaseID,
backups_core.BackupStatusCompleted,
)
if err != nil {
return fmt.Errorf(
"failed to find completed backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
// completedBackups are ordered newest first; delete everything beyond position RetentionCount
if len(completedBackups) <= backupConfig.RetentionCount {
return nil
}
toDelete := completedBackups[backupConfig.RetentionCount:]
for _, backup := range toDelete {
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by count policy",
"backupId",
backup.ID,
"error",
err,
)
continue
}
c.logger.Info(
"Deleted backup by count policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
"retentionCount", backupConfig.RetentionCount,
)
}
return nil
}
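// Worked example: with RetentionCount = 3 and five completed backups ordered
// newest-first, toDelete = completedBackups[3:] removes the two oldest,
// unless one of them is still inside the 60-minute grace period checked above.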
func (c *BackupCleaner) cleanByGFS(backupConfig *backups_config.BackupConfig) error {
if backupConfig.RetentionGfsHours <= 0 && backupConfig.RetentionGfsDays <= 0 &&
backupConfig.RetentionGfsWeeks <= 0 && backupConfig.RetentionGfsMonths <= 0 &&
backupConfig.RetentionGfsYears <= 0 {
return nil
}
completedBackups, err := c.backupRepository.FindByDatabaseIdAndStatus(
backupConfig.DatabaseID,
backups_core.BackupStatusCompleted,
)
if err != nil {
return fmt.Errorf(
"failed to find completed backups for database %s: %w",
backupConfig.DatabaseID,
err,
)
}
keepSet := buildGFSKeepSet(
completedBackups,
backupConfig.RetentionGfsHours,
backupConfig.RetentionGfsDays,
backupConfig.RetentionGfsWeeks,
backupConfig.RetentionGfsMonths,
backupConfig.RetentionGfsYears,
)
for _, backup := range completedBackups {
if keepSet[backup.ID] {
continue
}
if isRecentBackup(backup) {
continue
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete backup by GFS policy",
"backupId",
backup.ID,
"error",
err,
)
continue
}
c.logger.Info(
"Deleted backup by GFS policy",
"backupId", backup.ID,
"databaseId", backupConfig.DatabaseID,
)
}
return nil
}
func (c *BackupCleaner) cleanExceededBackupsForDatabase(
databaseID uuid.UUID,
limitPerDbMB int64,
) error {
for {
backupsTotalSizeMB, err := c.backupRepository.GetTotalSizeByDatabase(databaseID)
if err != nil {
return err
}
if backupsTotalSizeMB <= float64(limitPerDbMB) {
break
}
oldestBackups, err := c.backupRepository.FindOldestByDatabaseExcludingInProgress(
databaseID,
1,
)
if err != nil {
return err
}
if len(oldestBackups) == 0 {
c.logger.Warn(
"No backups to delete but still over limit",
"databaseId",
databaseID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitPerDbMB,
)
break
}
backup := oldestBackups[0]
if isRecentBackup(backup) {
c.logger.Warn(
"Oldest backup is too recent to delete, stopping size cleanup",
"databaseId",
databaseID,
"backupId",
backup.ID,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitPerDbMB,
)
break
}
if err := c.DeleteBackup(backup); err != nil {
c.logger.Error(
"Failed to delete exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"error",
err,
)
return err
}
c.logger.Info(
"Deleted exceeded backup",
"backupId",
backup.ID,
"databaseId",
databaseID,
"backupSizeMB",
backup.BackupSizeMb,
"totalSizeMB",
backupsTotalSizeMB,
"limitMB",
limitPerDbMB,
)
}
return nil
}
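// Worked example: with a 100 MB limit and two 60 MB backups, the loop deletes
// the oldest one (the total goes 120 -> 60, which is under the limit) and stops;
// if the oldest backup is still inside the grace period, cleanup halts instead.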
func isRecentBackup(backup *backups_core.Backup) bool {
return time.Since(backup.CreatedAt) < recentBackupGracePeriod
}
// buildGFSKeepSet determines which backups to retain under the GFS rotation scheme.
// Backups must be sorted newest-first. A backup can fill multiple slots simultaneously
// (e.g. the newest backup of a year also fills the monthly, weekly, daily, and hourly slots).
func buildGFSKeepSet(
backups []*backups_core.Backup,
hours, days, weeks, months, years int,
) map[uuid.UUID]bool {
keep := make(map[uuid.UUID]bool)
hoursSeen := make(map[string]bool)
daysSeen := make(map[string]bool)
weeksSeen := make(map[string]bool)
monthsSeen := make(map[string]bool)
yearsSeen := make(map[string]bool)
hoursKept, daysKept, weeksKept, monthsKept, yearsKept := 0, 0, 0, 0, 0
for _, backup := range backups {
t := backup.CreatedAt
hourKey := t.Format("2006-01-02-15")
dayKey := t.Format("2006-01-02")
weekYear, week := t.ISOWeek()
weekKey := fmt.Sprintf("%d-%02d", weekYear, week)
monthKey := t.Format("2006-01")
yearKey := t.Format("2006")
if hours > 0 && hoursKept < hours && !hoursSeen[hourKey] {
keep[backup.ID] = true
hoursSeen[hourKey] = true
hoursKept++
}
if days > 0 && daysKept < days && !daysSeen[dayKey] {
keep[backup.ID] = true
daysSeen[dayKey] = true
daysKept++
}
if weeks > 0 && weeksKept < weeks && !weeksSeen[weekKey] {
keep[backup.ID] = true
weeksSeen[weekKey] = true
weeksKept++
}
if months > 0 && monthsKept < months && !monthsSeen[monthKey] {
keep[backup.ID] = true
monthsSeen[monthKey] = true
monthsKept++
}
if years > 0 && yearsKept < years && !yearsSeen[yearKey] {
keep[backup.ID] = true
yearsSeen[yearKey] = true
yearsKept++
}
}
return keep
}
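A quick worked example of buildGFSKeepSet (synthetic timestamps, newest-first as required): with hours=0, days=2, weeks=1 and backups created at 2026-02-22 10:00, 2026-02-22 08:00, 2026-02-21 09:00 and 2026-02-15 09:00, the keep set holds the 02-22 10:00 backup (it fills both the 02-22 daily slot and the ISO-week slot) and the 02-21 09:00 backup (the second daily slot). The 02-22 08:00 backup loses its day slot to the newer backup, and by the time 02-15 09:00 is reached both the day and week quotas are exhausted, so it is deleted (subject to the grace period).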

File diff suppressed because it is too large

View File

@@ -1,43 +1,52 @@
package backuping
import (
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
backups_core "databasus-backend/internal/features/backups/backups/core"
"databasus-backend/internal/features/backups/backups/usecases"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
"time"
"github.com/google/uuid"
)
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = tasks_cancellation.GetTaskCancelManager()
var nodesRegistry = &BackupNodesRegistry{
var backupCleaner = &BackupCleaner{
backupRepository,
storages.GetStorageService(),
backups_config.GetBackupConfigService(),
encryption.GetFieldEncryptor(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
sync.Once{},
atomic.Bool{},
}
var backupNodesRegistry = &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
sync.Once{},
atomic.Bool{},
}
func getNodeID() uuid.UUID {
nodeIDStr := config.GetEnv().NodeID
nodeID, err := uuid.Parse(nodeIDStr)
if err != nil {
logger.GetLogger().Error("Failed to parse node ID from config", "error", err)
panic(err)
}
return nodeID
return uuid.New()
}
var backuperNode = &BackuperNode{
@@ -48,24 +57,28 @@ var backuperNode = &BackuperNode{
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
taskCancelManager,
backupNodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
getNodeID(),
time.Time{},
sync.Once{},
atomic.Bool{},
}
var backupsScheduler = &BackupsScheduler{
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
backupCancelManager,
nodesRegistry,
taskCancelManager,
backupNodesRegistry,
databases.GetDatabaseService(),
time.Now().UTC(),
logger.GetLogger(),
make(map[uuid.UUID]BackupToNodeRelation),
backuperNode,
sync.Once{},
atomic.Bool{},
}
func GetBackupsScheduler() *BackupsScheduler {
@@ -75,3 +88,11 @@ func GetBackupsScheduler() *BackupsScheduler {
func GetBackuperNode() *BackuperNode {
return backuperNode
}
func GetBackupNodesRegistry() *BackupNodesRegistry {
return backupNodesRegistry
}
func GetBackupCleaner() *BackupCleaner {
return backupCleaner
}

View File

@@ -6,6 +6,11 @@ import (
"github.com/google/uuid"
)
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
}
type BackupNode struct {
ID uuid.UUID `json:"id"`
ThroughputMBs int `json:"throughputMBs"`
@@ -18,17 +23,12 @@ type BackupNodeStats struct {
}
type BackupSubmitMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
IsCallNotifier bool `json:"isCallNotifier"`
}
type BackupCompletionMessage struct {
NodeID string `json:"nodeId"`
BackupID string `json:"backupId"`
}
type BackupToNodeRelation struct {
NodeID uuid.UUID `json:"nodeId"`
BackupsIDs []uuid.UUID `json:"backupsIds"`
NodeID uuid.UUID `json:"nodeId"`
BackupID uuid.UUID `json:"backupId"`
}
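Note that moving these fields from string to uuid.UUID does not change the wire format: github.com/google/uuid marshals a UUID to its canonical string form in JSON. An illustrative BackupSubmitMessage payload:

	{"nodeId":"1b4e28ba-2fa1-11d2-883f-0016d3cca427","backupId":"6fa459ea-ee8a-3ca4-894e-db77e160355e","isCallNotifier":true}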

View File

@@ -1,8 +1,19 @@
package backuping
import (
"databasus-backend/internal/features/notifiers"
"context"
"errors"
"sync/atomic"
"time"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
)
@@ -17,3 +28,168 @@ func (m *MockNotificationSender) SendNotification(
) {
m.Called(notifier, title, message)
}
type CreateFailedBackupUsecase struct{}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct{}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateLargeBackupUsecase simulates a large backup (10000 MB)
type CreateLargeBackupUsecase struct{}
func (uc *CreateLargeBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(10000)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateProgressiveBackupUsecase simulates progressive size updates that exceed the limit
type CreateProgressiveBackupUsecase struct{}
func (uc *CreateProgressiveBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
// Simulate a progressive backup that grows beyond the limit
backupProgressListener(1)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(3)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(5)
if ctx.Err() != nil {
return nil, ctx.Err()
}
backupProgressListener(10) // This exceeds the 5 MB limit
if ctx.Err() != nil {
return nil, ctx.Err()
}
// Should not reach here due to cancellation
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// CreateMediumBackupUsecase simulates a 50 MB backup
type CreateMediumBackupUsecase struct{}
func (uc *CreateMediumBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
backupProgressListener(50)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
// MockTrackingBackupUsecase tracks backup use case calls for testing parallel execution
type MockTrackingBackupUsecase struct {
callCount atomic.Int32
calledBackupIDs chan uuid.UUID
}
func NewMockTrackingBackupUsecase() *MockTrackingBackupUsecase {
return &MockTrackingBackupUsecase{
calledBackupIDs: make(chan uuid.UUID, 10),
}
}
func (m *MockTrackingBackupUsecase) Execute(
ctx context.Context,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
backupProgressListener func(completedMBs float64),
) (*common.BackupMetadata, error) {
m.callCount.Add(1)
// Send backup ID to channel (non-blocking)
select {
case m.calledBackupIDs <- backup.ID:
default:
}
// Simulate backup work
time.Sleep(100 * time.Millisecond)
backupProgressListener(10)
return &common.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}
func (m *MockTrackingBackupUsecase) GetCallCount() int32 {
return m.callCount.Load()
}
func (m *MockTrackingBackupUsecase) GetCalledBackupIDs() []uuid.UUID {
ids := []uuid.UUID{}
for {
select {
case id := <-m.calledBackupIDs:
ids = append(ids, id)
default:
return ids
}
}
}
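A typical use of the tracking mock, assuming the CreateTestBackuperNode helper seen in the tests above:

	uc := NewMockTrackingBackupUsecase()
	backuperNode := CreateTestBackuperNode()
	backuperNode.createBackupUseCase = uc
	backuperNode.MakeBackup(backup.ID, false)
	assert.Equal(t, int32(1), uc.GetCallCount())
	assert.Contains(t, uc.GetCalledBackupIDs(), backup.ID)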

View File

@@ -6,6 +6,8 @@ import (
"fmt"
"log/slog"
"strings"
"sync"
"sync/atomic"
"time"
cache_utils "databasus-backend/internal/util/cache"
@@ -15,25 +17,70 @@ import (
)
const (
nodeInfoKeyPrefix = "node:"
nodeInfoKeyPrefix = "backup:node:"
nodeInfoKeySuffix = ":info"
nodeActiveBackupsPrefix = "node:"
nodeActiveBackupsPrefix = "backup:node:"
nodeActiveBackupsSuffix = ":active_backups"
backupSubmitChannel = "backup:submit"
backupCompletionChannel = "backup:completion"
deadNodeThreshold = 2 * time.Minute
cleanupTickerInterval = 1 * time.Second
)
// BackupNodesRegistry keeps the backups scheduler and the backup nodes in sync.
//
// Features:
// - Tracks node availability and load level
// - Assigns backups from the scheduler to the nodes that should process them
// - Notifies the scheduler when a node completes a backup
//
// Important things to remember:
// - Nodes without a heartbeat for more than 2 minutes are not included
// in the available nodes list and stats
//
// Dead node cleanup is performed on two levels:
// - List and stats functions do not return dead nodes
// - Dead nodes are periodically removed from the cache (so that they
// do not accumulate there)
type BackupNodesRegistry struct {
client valkey.Client
logger *slog.Logger
timeout time.Duration
pubsubBackups *cache_utils.PubSubManager
pubsubCompletions *cache_utils.PubSubManager
runOnce sync.Once
hasRun atomic.Bool
}
func (r *BackupNodesRegistry) Run(ctx context.Context) {
wasAlreadyRun := r.hasRun.Load()
r.runOnce.Do(func() {
r.hasRun.Store(true)
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes on startup", "error", err)
}
ticker := time.NewTicker(cleanupTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := r.cleanupDeadNodes(); err != nil {
r.logger.Error("Failed to cleanup dead nodes", "error", err)
}
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", r))
}
}
func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
@@ -76,13 +123,30 @@ func (r *BackupNodesRegistry) GetAvailableNodes() ([]BackupNode, error) {
return nil, fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var nodes []BackupNode
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
nodes = append(nodes, node)
}
@@ -129,18 +193,54 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
return nil, fmt.Errorf("failed to pipeline get active backups keys: %w", err)
}
var stats []BackupNodeStats
for key, data := range keyDataMap {
var nodeInfoKeys []string
nodeIDToStatsKey := make(map[string]string)
for key := range keyDataMap {
nodeID := r.extractNodeIDFromKey(key, nodeActiveBackupsPrefix, nodeActiveBackupsSuffix)
nodeIDStr := nodeID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeIDStr, nodeInfoKeySuffix)
nodeInfoKeys = append(nodeInfoKeys, infoKey)
nodeIDToStatsKey[infoKey] = key
}
count, err := r.parseIntFromBytes(data)
nodeInfoMap, err := r.pipelineGetKeys(nodeInfoKeys)
if err != nil {
return nil, fmt.Errorf("failed to pipeline get node info keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var stats []BackupNodeStats
for infoKey, nodeData := range nodeInfoMap {
// Skip if the info key doesn't exist (nodeData is empty)
if len(nodeData) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(nodeData, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data", "key", infoKey, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
continue
}
statsKey := nodeIDToStatsKey[infoKey]
tasksData := keyDataMap[statsKey]
count, err := r.parseIntFromBytes(tasksData)
if err != nil {
r.logger.Warn("Failed to parse active backups count", "key", key, "error", err)
r.logger.Warn("Failed to parse active backups count", "key", statsKey, "error", err)
continue
}
stat := BackupNodeStats{
ID: nodeID,
ID: node.ID,
ActiveBackups: int(count),
}
stats = append(stats, stat)
@@ -149,11 +249,11 @@ func (r *BackupNodesRegistry) GetBackupNodesStats() ([]BackupNodeStats, error) {
return stats, nil
}
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Incr().Key(key).Build())
if result.Error() != nil {
@@ -167,11 +267,11 @@ func (r *BackupNodesRegistry) IncrementBackupsInProgress(nodeID string) error {
return nil
}
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID uuid.UUID) error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID, nodeActiveBackupsSuffix)
key := fmt.Sprintf("%s%s%s", nodeActiveBackupsPrefix, nodeID.String(), nodeActiveBackupsSuffix)
result := r.client.Do(ctx, r.client.B().Decr().Key(key).Build())
if result.Error() != nil {
@@ -198,6 +298,10 @@ func (r *BackupNodesRegistry) DecrementBackupsInProgress(nodeID string) error {
}
func (r *BackupNodesRegistry) HearthbeatNodeInRegistry(now time.Time, backupNode BackupNode) error {
if now.IsZero() {
return fmt.Errorf("cannot register node with zero heartbeat timestamp")
}
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
@@ -247,7 +351,7 @@ func (r *BackupNodesRegistry) UnregisterNodeFromRegistry(backupNode BackupNode)
}
func (r *BackupNodesRegistry) AssignBackupToNode(
targetNodeID string,
targetNodeID uuid.UUID,
backupID uuid.UUID,
isCallNotifier bool,
) error {
@@ -255,7 +359,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
message := BackupSubmitMessage{
NodeID: targetNodeID,
BackupID: backupID.String(),
BackupID: backupID,
IsCallNotifier: isCallNotifier,
}
@@ -273,7 +377,7 @@ func (r *BackupNodesRegistry) AssignBackupToNode(
}
func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
nodeID string,
nodeID uuid.UUID,
handler func(backupID uuid.UUID, isCallNotifier bool),
) error {
ctx := context.Background()
@@ -289,19 +393,7 @@ func (r *BackupNodesRegistry) SubscribeNodeForBackupsAssignment(
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(backupID, msg.IsCallNotifier)
handler(msg.BackupID, msg.IsCallNotifier)
}
err := r.pubsubBackups.Subscribe(ctx, backupSubmitChannel, wrappedHandler)
@@ -323,12 +415,12 @@ func (r *BackupNodesRegistry) UnsubscribeNodeForBackupsAssignments() error {
return nil
}
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uuid.UUID) error {
func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID uuid.UUID, backupID uuid.UUID) error {
ctx := context.Background()
message := BackupCompletionMessage{
NodeID: nodeID,
BackupID: backupID.String(),
BackupID: backupID,
}
messageJSON, err := json.Marshal(message)
@@ -345,7 +437,7 @@ func (r *BackupNodesRegistry) PublishBackupCompletion(nodeID string, backupID uu
}
func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
handler func(nodeID string, backupID uuid.UUID),
handler func(nodeID uuid.UUID, backupID uuid.UUID),
) error {
ctx := context.Background()
@@ -356,19 +448,7 @@ func (r *BackupNodesRegistry) SubscribeForBackupsCompletions(
return
}
backupID, err := uuid.Parse(msg.BackupID)
if err != nil {
r.logger.Warn(
"Failed to parse backup ID from completion message",
"backupId",
msg.BackupID,
"error",
err,
)
return
}
handler(msg.NodeID, backupID)
handler(msg.NodeID, msg.BackupID)
}
err := r.pubsubCompletions.Subscribe(ctx, backupCompletionChannel, wrappedHandler)
@@ -446,3 +526,108 @@ func (r *BackupNodesRegistry) parseIntFromBytes(data []byte) (int64, error) {
}
return count, nil
}
func (r *BackupNodesRegistry) cleanupDeadNodes() error {
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
var allKeys []string
cursor := uint64(0)
pattern := nodeInfoKeyPrefix + "*" + nodeInfoKeySuffix
for {
result := r.client.Do(
ctx,
r.client.B().Scan().Cursor(cursor).Match(pattern).Count(1_000).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to scan node keys: %w", result.Error())
}
scanResult, err := result.AsScanEntry()
if err != nil {
return fmt.Errorf("failed to parse scan result: %w", err)
}
allKeys = append(allKeys, scanResult.Elements...)
cursor = scanResult.Cursor
if cursor == 0 {
break
}
}
if len(allKeys) == 0 {
return nil
}
keyDataMap, err := r.pipelineGetKeys(allKeys)
if err != nil {
return fmt.Errorf("failed to pipeline get node keys: %w", err)
}
threshold := time.Now().UTC().Add(-deadNodeThreshold)
var deadNodeKeys []string
for key, data := range keyDataMap {
// Skip if the key doesn't exist (data is empty)
if len(data) == 0 {
continue
}
var node BackupNode
if err := json.Unmarshal(data, &node); err != nil {
r.logger.Warn("Failed to unmarshal node data during cleanup", "key", key, "error", err)
continue
}
// Skip nodes with zero/uninitialized heartbeat
if node.LastHeartbeat.IsZero() {
continue
}
if node.LastHeartbeat.Before(threshold) {
nodeID := node.ID.String()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, nodeID, nodeInfoKeySuffix)
statsKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
nodeID,
nodeActiveBackupsSuffix,
)
deadNodeKeys = append(deadNodeKeys, infoKey, statsKey)
r.logger.Info(
"Marking node for cleanup",
"nodeID", nodeID,
"lastHeartbeat", node.LastHeartbeat,
"threshold", threshold,
)
}
}
if len(deadNodeKeys) == 0 {
return nil
}
delCtx, delCancel := context.WithTimeout(context.Background(), r.timeout)
defer delCancel()
result := r.client.Do(
delCtx,
r.client.B().Del().Key(deadNodeKeys...).Build(),
)
if result.Error() != nil {
return fmt.Errorf("failed to delete dead node keys: %w", result.Error())
}
deletedCount, err := result.AsInt64()
if err != nil {
return fmt.Errorf("failed to parse deleted count: %w", err)
}
r.logger.Info("Cleaned up dead nodes", "deletedKeysCount", deletedCount)
return nil
}
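Taken together, a node's lifecycle against this registry looks roughly like the sketch below (method names are the ones defined in this file; the heartbeat cadence is an assumption and only needs to stay well under the 2-minute deadNodeThreshold):

	node := BackupNode{ID: uuid.New(), ThroughputMBs: 100}
	// heartbeat periodically so the node stays visible to the scheduler
	_ = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
	// around each assigned backup
	_ = registry.IncrementBackupsInProgress(node.ID)
	// ... perform the backup ...
	_ = registry.DecrementBackupsInProgress(node.ID)
	// on clean shutdown; otherwise cleanupDeadNodes removes the keys later
	_ = registry.UnregisterNodeFromRegistry(node)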

View File

@@ -2,6 +2,10 @@ package backuping
import (
"context"
"encoding/json"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -36,7 +40,7 @@ func Test_UnregisterNodeFromRegistry_RemovesNodeAndCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.UnregisterNodeFromRegistry(node)
@@ -100,7 +104,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -109,7 +113,7 @@ func Test_IncrementBackupsInProgress_IncrementsCounter(t *testing.T) {
assert.Equal(t, node.ID, stats[0].ID)
assert.Equal(t, 1, stats[0].ActiveBackups)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -127,25 +131,25 @@ func Test_DecrementBackupsInProgress_DecrementsCounter(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node.ID.String())
err = registry.IncrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 3, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Equal(t, 2, stats[0].ActiveBackups)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -162,7 +166,7 @@ func Test_DecrementBackupsInProgress_WhenNegative_ResetsToZero(t *testing.T) {
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node)
assert.NoError(t, err)
err = registry.DecrementBackupsInProgress(node.ID.String())
err = registry.DecrementBackupsInProgress(node.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -188,19 +192,19 @@ func Test_GetBackupNodesStats_ReturnsStatsForAllNodes(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID.String())
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -274,12 +278,12 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID.String())
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID.String())
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
stats, err := registry.GetBackupNodesStats()
@@ -294,7 +298,7 @@ func Test_BackupCounters_TrackedSeparatelyPerNode(t *testing.T) {
assert.Equal(t, 2, statsMap[node1.ID])
assert.Equal(t, 1, statsMap[node2.ID])
err = registry.DecrementBackupsInProgress(node1.ID.String())
err = registry.DecrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
stats, err = registry.GetBackupNodesStats()
@@ -366,13 +370,239 @@ func Test_HearthbeatNodeInRegistry_UpdatesLastHeartbeat(t *testing.T) {
assert.True(t, nodes[0].LastHeartbeat.After(originalHeartbeat))
}
func Test_HearthbeatNodeInRegistry_RejectsZeroTimestamp(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node := createTestBackupNode()
err := registry.HearthbeatNodeInRegistry(time.Time{}, node)
assert.Error(t, err)
assert.Contains(t, err.Error(), "zero heartbeat timestamp")
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 0)
}
func Test_GetAvailableNodes_ExcludesStaleNodesFromCache(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 2)
nodeIDs := make(map[uuid.UUID]bool)
for _, n := range nodes {
nodeIDs[n.ID] = true
}
assert.True(t, nodeIDs[node1.ID])
assert.False(t, nodeIDs[node2.ID])
assert.True(t, nodeIDs[node3.ID])
}
func Test_GetBackupNodesStats_ExcludesStaleNodesFromCache(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
node3 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
defer cleanupTestNode(registry, node3)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node3)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node3.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 2)
statsMap := make(map[uuid.UUID]int)
for _, stat := range stats {
statsMap[stat.ID] = stat.ActiveBackups
}
assert.Equal(t, 1, statsMap[node1.ID])
_, hasNode2 := statsMap[node2.ID]
assert.False(t, hasNode2)
assert.Equal(t, 1, statsMap[node3.ID])
}
func Test_CleanupDeadNodes_RemovesNodeInfoAndCounter(t *testing.T) {
cache_utils.ClearAllCache()
registry := createTestRegistry()
node1 := createTestBackupNode()
node2 := createTestBackupNode()
defer cleanupTestNode(registry, node1)
defer cleanupTestNode(registry, node2)
err := registry.HearthbeatNodeInRegistry(time.Now().UTC(), node1)
assert.NoError(t, err)
err = registry.HearthbeatNodeInRegistry(time.Now().UTC(), node2)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node1.ID)
assert.NoError(t, err)
err = registry.IncrementBackupsInProgress(node2.ID)
assert.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), registry.timeout)
defer cancel()
key := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
result := registry.client.Do(ctx, registry.client.B().Get().Key(key).Build())
assert.NoError(t, result.Error())
data, err := result.AsBytes()
assert.NoError(t, err)
var node BackupNode
err = json.Unmarshal(data, &node)
assert.NoError(t, err)
node.LastHeartbeat = time.Now().UTC().Add(-3 * time.Minute)
modifiedData, err := json.Marshal(node)
assert.NoError(t, err)
setCtx, setCancel := context.WithTimeout(context.Background(), registry.timeout)
defer setCancel()
setResult := registry.client.Do(
setCtx,
registry.client.B().Set().Key(key).Value(string(modifiedData)).Build(),
)
assert.NoError(t, setResult.Error())
err = registry.cleanupDeadNodes()
assert.NoError(t, err)
checkCtx, checkCancel := context.WithTimeout(context.Background(), registry.timeout)
defer checkCancel()
infoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node2.ID.String(), nodeInfoKeySuffix)
infoResult := registry.client.Do(checkCtx, registry.client.B().Get().Key(infoKey).Build())
assert.Error(t, infoResult.Error())
counterKey := fmt.Sprintf(
"%s%s%s",
nodeActiveBackupsPrefix,
node2.ID.String(),
nodeActiveBackupsSuffix,
)
counterCtx, counterCancel := context.WithTimeout(context.Background(), registry.timeout)
defer counterCancel()
counterResult := registry.client.Do(
counterCtx,
registry.client.B().Get().Key(counterKey).Build(),
)
assert.Error(t, counterResult.Error())
activeInfoKey := fmt.Sprintf("%s%s%s", nodeInfoKeyPrefix, node1.ID.String(), nodeInfoKeySuffix)
activeCtx, activeCancel := context.WithTimeout(context.Background(), registry.timeout)
defer activeCancel()
activeResult := registry.client.Do(
activeCtx,
registry.client.B().Get().Key(activeInfoKey).Build(),
)
assert.NoError(t, activeResult.Error())
nodes, err := registry.GetAvailableNodes()
assert.NoError(t, err)
assert.Len(t, nodes, 1)
assert.Equal(t, node1.ID, nodes[0].ID)
stats, err := registry.GetBackupNodesStats()
assert.NoError(t, err)
assert.Len(t, stats, 1)
assert.Equal(t, node1.ID, stats[0].ID)
}
func createTestRegistry() *BackupNodesRegistry {
return &BackupNodesRegistry{
cache_utils.GetValkeyClient(),
logger.GetLogger(),
cache_utils.DefaultCacheTimeout,
cache_utils.NewPubSubManager(),
cache_utils.NewPubSubManager(),
client: cache_utils.GetValkeyClient(),
logger: logger.GetLogger(),
timeout: cache_utils.DefaultCacheTimeout,
pubsubBackups: cache_utils.NewPubSubManager(),
pubsubCompletions: cache_utils.NewPubSubManager(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -394,7 +624,7 @@ func Test_AssignBackupTonode_PublishesJsonMessageToChannel(t *testing.T) {
node := createTestBackupNode()
backupID := uuid.New()
err := registry.AssignBackupToNode(node.ID.String(), backupID, true)
err := registry.AssignBackupToNode(node.ID, backupID, true)
assert.NoError(t, err)
}
@@ -410,12 +640,12 @@ func Test_SubscribeNodeForBackupsAssignment_ReceivesSubmittedBackupsForMatchingN
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID, true)
err = registry.AssignBackupToNode(node.ID, backupID, true)
assert.NoError(t, err)
select {
@@ -439,12 +669,12 @@ func Test_SubscribeNodeForBackupsAssignment_FiltersOutBackupsForDifferentNode(t
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node1.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node2.ID.String(), backupID, false)
err = registry.AssignBackupToNode(node2.ID, backupID, false)
assert.NoError(t, err)
select {
@@ -467,15 +697,15 @@ func Test_SubscribeNodeForBackupsAssignment_ParsesJsonAndBackupIdCorrectly(t *te
receivedBackups <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
err = registry.AssignBackupToNode(node.ID, backupID1, true)
assert.NoError(t, err)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
err = registry.AssignBackupToNode(node.ID, backupID2, false)
assert.NoError(t, err)
received1 := <-receivedBackups
@@ -497,7 +727,7 @@ func Test_SubscribeNodeForBackupsAssignment_HandlesInvalidJson(t *testing.T) {
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
@@ -525,12 +755,12 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
receivedBackupID <- id
}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID1, true)
err = registry.AssignBackupToNode(node.ID, backupID1, true)
assert.NoError(t, err)
received := <-receivedBackupID
@@ -541,7 +771,7 @@ func Test_UnsubscribeNodeForBackupsAssignments_StopsReceivingMessages(t *testing
time.Sleep(100 * time.Millisecond)
err = registry.AssignBackupToNode(node.ID.String(), backupID2, false)
err = registry.AssignBackupToNode(node.ID, backupID2, false)
assert.NoError(t, err)
select {
@@ -559,10 +789,10 @@ func Test_SubscribeNodeForBackupsAssignment_WhenAlreadySubscribed_ReturnsError(t
handler := func(id uuid.UUID, isCallNotifier bool) {}
err := registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err := registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.NoError(t, err)
err = registry.SubscribeNodeForBackupsAssignment(node.ID.String(), handler)
err = registry.SubscribeNodeForBackupsAssignment(node.ID, handler)
assert.Error(t, err)
assert.Contains(t, err.Error(), "already subscribed")
}
@@ -593,25 +823,25 @@ func Test_MultipleNodes_EachReceivesOnlyTheirBackups(t *testing.T) {
handler2 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups2 <- id }
handler3 := func(id uuid.UUID, isCallNotifier bool) { receivedBackups3 <- id }
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID.String(), handler1)
err := registry1.SubscribeNodeForBackupsAssignment(node1.ID, handler1)
assert.NoError(t, err)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID.String(), handler2)
err = registry2.SubscribeNodeForBackupsAssignment(node2.ID, handler2)
assert.NoError(t, err)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID.String(), handler3)
err = registry3.SubscribeNodeForBackupsAssignment(node3.ID, handler3)
assert.NoError(t, err)
time.Sleep(100 * time.Millisecond)
submitRegistry := createTestRegistry()
err = submitRegistry.AssignBackupToNode(node1.ID.String(), backupID1, true)
err = submitRegistry.AssignBackupToNode(node1.ID, backupID1, true)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node2.ID.String(), backupID2, false)
err = submitRegistry.AssignBackupToNode(node2.ID, backupID2, false)
assert.NoError(t, err)
err = submitRegistry.AssignBackupToNode(node3.ID.String(), backupID3, true)
err = submitRegistry.AssignBackupToNode(node3.ID, backupID3, true)
assert.NoError(t, err)
select {
@@ -660,7 +890,7 @@ func Test_PublishBackupCompletion_PublishesMessageToChannel(t *testing.T) {
node := createTestBackupNode()
backupID := uuid.New()
err := registry.PublishBackupCompletion(node.ID.String(), backupID)
err := registry.PublishBackupCompletion(node.ID, backupID)
assert.NoError(t, err)
}
@@ -672,8 +902,8 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
receivedNodeID := make(chan string, 1)
handler := func(nodeID string, backupID uuid.UUID) {
receivedNodeID := make(chan uuid.UUID, 1)
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedNodeID <- nodeID
receivedBackupID <- backupID
}
@@ -683,12 +913,12 @@ func Test_SubscribeForBackupsCompletions_ReceivesCompletedBackups(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID)
err = registry.PublishBackupCompletion(node.ID, backupID)
assert.NoError(t, err)
select {
case receivedNode := <-receivedNodeID:
assert.Equal(t, node.ID.String(), receivedNode)
assert.Equal(t, node.ID, receivedNode)
case <-time.After(2 * time.Second):
t.Fatal("Timeout waiting for node ID")
}
@@ -710,7 +940,7 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
defer registry.UnsubscribeForBackupsCompletions()
receivedBackups := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackups <- backupID
}
@@ -719,10 +949,10 @@ func Test_SubscribeForBackupsCompletions_ParsesJsonCorrectly(t *testing.T) {
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
err = registry.PublishBackupCompletion(node.ID, backupID1)
assert.NoError(t, err)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
err = registry.PublishBackupCompletion(node.ID, backupID2)
assert.NoError(t, err)
received1 := <-receivedBackups
@@ -739,7 +969,7 @@ func Test_SubscribeForBackupsCompletions_HandlesInvalidJson(t *testing.T) {
defer registry.UnsubscribeForBackupsCompletions()
receivedBackupID := make(chan uuid.UUID, 1)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackupID <- backupID
}
@@ -767,7 +997,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
backupID2 := uuid.New()
receivedBackupID := make(chan uuid.UUID, 2)
handler := func(nodeID string, backupID uuid.UUID) {
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {
receivedBackupID <- backupID
}
@@ -776,7 +1006,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID1)
err = registry.PublishBackupCompletion(node.ID, backupID1)
assert.NoError(t, err)
received := <-receivedBackupID
@@ -787,7 +1017,7 @@ func Test_UnsubscribeForBackupsCompletions_StopsReceivingMessages(t *testing.T)
time.Sleep(100 * time.Millisecond)
err = registry.PublishBackupCompletion(node.ID.String(), backupID2)
err = registry.PublishBackupCompletion(node.ID, backupID2)
assert.NoError(t, err)
select {
@@ -802,7 +1032,7 @@ func Test_SubscribeForBackupsCompletions_WhenAlreadySubscribed_ReturnsError(t *t
registry := createTestRegistry()
defer registry.UnsubscribeForBackupsCompletions()
handler := func(nodeID string, backupID uuid.UUID) {}
handler := func(nodeID uuid.UUID, backupID uuid.UUID) {}
err := registry.SubscribeForBackupsCompletions(handler)
assert.NoError(t, err)
@@ -834,9 +1064,9 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
receivedBackups2 := make(chan uuid.UUID, 3)
receivedBackups3 := make(chan uuid.UUID, 3)
handler1 := func(nodeID string, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID string, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID string, backupID uuid.UUID) { receivedBackups3 <- backupID }
handler1 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups1 <- backupID }
handler2 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups2 <- backupID }
handler3 := func(nodeID uuid.UUID, backupID uuid.UUID) { receivedBackups3 <- backupID }
err := registry1.SubscribeForBackupsCompletions(handler1)
assert.NoError(t, err)
@@ -850,13 +1080,13 @@ func Test_MultipleSubscribers_EachReceivesCompletionMessages(t *testing.T) {
time.Sleep(100 * time.Millisecond)
publishRegistry := createTestRegistry()
err = publishRegistry.PublishBackupCompletion(node1.ID.String(), backupID1)
err = publishRegistry.PublishBackupCompletion(node1.ID, backupID1)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node2.ID.String(), backupID2)
err = publishRegistry.PublishBackupCompletion(node2.ID, backupID2)
assert.NoError(t, err)
err = publishRegistry.PublishBackupCompletion(node3.ID.String(), backupID3)
err = publishRegistry.PublishBackupCompletion(node3.ID, backupID3)
assert.NoError(t, err)
receivedAll1 := []uuid.UUID{}

View File

@@ -2,145 +2,156 @@ package backuping
import (
"context"
"databasus-backend/internal/config"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/period"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"databasus-backend/internal/config"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
files_utils "databasus-backend/internal/util/files"
)
const (
schedulerStartupDelay = 1 * time.Minute
schedulerTickerInterval = 1 * time.Minute
schedulerHealthcheckThreshold = 5 * time.Minute
)
type BackupsScheduler struct {
backupRepository *backups_core.BackupRepository
backupConfigService *backups_config.BackupConfigService
storageService *storages.StorageService
backupCancelManager *backups_cancellation.BackupCancelManager
nodesRegistry *BackupNodesRegistry
taskCancelManager *task_cancellation.TaskCancelManager
backupNodesRegistry *BackupNodesRegistry
databaseService *databases.DatabaseService
lastBackupTime time.Time
logger *slog.Logger
backupToNodeRelations map[uuid.UUID]BackupToNodeRelation
backuperNode *BackuperNode
runOnce sync.Once
hasRun atomic.Bool
}
func (s *BackupsScheduler) Run(ctx context.Context) {
s.lastBackupTime = time.Now().UTC()
wasAlreadyRun := s.hasRun.Load()
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(1 * time.Minute)
}
s.runOnce.Do(func() {
s.hasRun.Store(true)
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
s.lastBackupTime = time.Now().UTC()
if err := s.nodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted); err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
defer func() {
if err := s.nodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
if config.GetEnv().IsManyNodesMode {
// wait for other nodes to start
time.Sleep(schedulerStartupDelay)
}
}()
if ctx.Err() != nil {
return
}
if err := s.failBackupsInProgress(); err != nil {
s.logger.Error("Failed to fail backups in progress", "error", err)
panic(err)
}
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
err := s.backupNodesRegistry.SubscribeForBackupsCompletions(s.onBackupCompleted)
if err != nil {
s.logger.Error("Failed to subscribe to backup completions", "error", err)
panic(err)
}
for {
select {
case <-ctx.Done():
defer func() {
if err := s.backupNodesRegistry.UnsubscribeForBackupsCompletions(); err != nil {
s.logger.Error("Failed to unsubscribe from backup completions", "error", err)
}
}()
if ctx.Err() != nil {
return
case <-ticker.C:
if err := s.cleanOldBackups(); err != nil {
s.logger.Error("Failed to clean old backups", "error", err)
}
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
ticker := time.NewTicker(schedulerTickerInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if err := s.checkDeadNodesAndFailBackups(); err != nil {
s.logger.Error("Failed to check dead nodes and fail backups", "error", err)
}
if err := s.runPendingBackups(); err != nil {
s.logger.Error("Failed to run pending backups", "error", err)
}
s.lastBackupTime = time.Now().UTC()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
func (s *BackupsScheduler) IsSchedulerRunning() bool {
// if lastBackupTime is older than the healthcheck threshold, report the scheduler as not running
return s.lastBackupTime.After(time.Now().UTC().Add(-5 * time.Minute))
return s.lastBackupTime.After(time.Now().UTC().Add(-schedulerHealthcheckThreshold))
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
func (s *BackupsScheduler) IsBackupNodesAvailable() bool {
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return err
s.logger.Error("Failed to get available nodes for health check", "error", err)
return false
}
fmt.Println("Backups in progress", len(backupsInProgress))
for _, backup := range backupsInProgress {
if err := s.backupCancelManager.CancelBackup(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via context manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
return len(nodes) > 0
}
func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(databaseID)
func (s *BackupsScheduler) StartBackup(database *databases.Database, isCallNotifier bool) {
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(database.ID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
return
}
if backupConfig.StorageID == nil {
s.logger.Error("Backup config storage ID is nil", "databaseId", databaseID)
s.logger.Error("Backup config storage ID is nil", "databaseId", database.ID)
return
}
// Check for existing in-progress backups
inProgressBackups, err := s.backupRepository.FindByDatabaseIdAndStatus(
database.ID,
backups_core.BackupStatusInProgress,
)
if err != nil {
s.logger.Error(
"Failed to check for in-progress backups",
"databaseId",
database.ID,
"error",
err,
)
return
}
if len(inProgressBackups) > 0 {
s.logger.Warn(
"Backup already in progress for database, skipping new backup",
"databaseId",
database.ID,
"existingBackupId",
inProgressBackups[0].ID,
)
return
}
@@ -156,12 +167,22 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
backupID := uuid.New()
timestamp := time.Now().UTC()
backup := &backups_core.Backup{
ID: backupID,
FileName: fmt.Sprintf(
"%s-%s-%s",
files_utils.SanitizeFilename(database.Name),
timestamp.Format("20060102-150405"),
backupID.String(),
),
DatabaseID: backupConfig.DatabaseID,
StorageID: *backupConfig.StorageID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
CreatedAt: timestamp,
}
if err := s.backupRepository.Save(backup); err != nil {
@@ -175,7 +196,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.IncrementBackupsInProgress(leastBusyNodeID.String()); err != nil {
if err := s.backupNodesRegistry.IncrementBackupsInProgress(*leastBusyNodeID); err != nil {
s.logger.Error(
"Failed to increment backups in progress",
"nodeId",
@@ -188,7 +209,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
return
}
if err := s.nodesRegistry.AssignBackupToNode(leastBusyNodeID.String(), backup.ID, isCallNotifier); err != nil {
if err := s.backupNodesRegistry.AssignBackupToNode(*leastBusyNodeID, backup.ID, isCallNotifier); err != nil {
s.logger.Error(
"Failed to submit backup",
"nodeId",
@@ -198,7 +219,7 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
"error",
err,
)
if decrementErr := s.nodesRegistry.DecrementBackupsInProgress(leastBusyNodeID.String()); decrementErr != nil {
if decrementErr := s.backupNodesRegistry.DecrementBackupsInProgress(*leastBusyNodeID); decrementErr != nil {
s.logger.Error(
"Failed to decrement backups in progress after submit failure",
"nodeId",
@@ -215,8 +236,8 @@ func (s *BackupsScheduler) StartBackup(databaseID uuid.UUID, isCallNotifier bool
s.backupToNodeRelations[*leastBusyNodeID] = relation
} else {
s.backupToNodeRelations[*leastBusyNodeID] = BackupToNodeRelation{
NodeID: *leastBusyNodeID,
BackupsIDs: []uuid.UUID{backup.ID},
*leastBusyNodeID,
[]uuid.UUID{backup.ID},
}
}
@@ -244,6 +265,10 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return 0
}
if lastBackup.IsSkipRetry {
return 0
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(lastBackup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
@@ -276,74 +301,6 @@ func (s *BackupsScheduler) GetRemainedBackupTryCount(lastBackup *backups_core.Ba
return maxFailedTriesCount - len(lastFailedBackups)
}
func (s *BackupsScheduler) cleanOldBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
return err
}
for _, backupConfig := range enabledBackupConfigs {
backupStorePeriod := backupConfig.StorePeriod
if backupStorePeriod == period.PeriodForever {
continue
}
storeDuration := backupStorePeriod.ToDuration()
dateBeforeBackupsShouldBeDeleted := time.Now().UTC().Add(-storeDuration)
oldBackups, err := s.backupRepository.FindBackupsBeforeDate(
backupConfig.DatabaseID,
dateBeforeBackupsShouldBeDeleted,
)
if err != nil {
s.logger.Error(
"Failed to find old backups for database",
"databaseId",
backupConfig.DatabaseID,
"error",
err,
)
continue
}
for _, backup := range oldBackups {
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
s.logger.Error(
"Failed to get storage by ID",
"storageId",
backup.StorageID,
"error",
err,
)
continue
}
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}
if err := s.backupRepository.DeleteByID(backup.ID); err != nil {
s.logger.Error("Failed to delete old backup", "backupId", backup.ID, "error", err)
continue
}
s.logger.Info(
"Deleted old backup",
"backupId",
backup.ID,
"databaseId",
backupConfig.DatabaseID,
)
}
}
return nil
}
func (s *BackupsScheduler) runPendingBackups() error {
enabledBackupConfigs, err := s.backupConfigService.GetBackupConfigsWithEnabledBackups()
if err != nil {
@@ -384,7 +341,13 @@ func (s *BackupsScheduler) runPendingBackups() error {
backupConfig.BackupInterval.Interval,
)
s.StartBackup(backupConfig.DatabaseID, remainedBackupTryCount == 1)
database, err := s.databaseService.GetDatabaseByID(backupConfig.DatabaseID)
if err != nil {
s.logger.Error("Failed to get database by ID", "error", err)
continue
}
s.StartBackup(database, remainedBackupTryCount == 1)
continue
}
}
@@ -392,8 +355,51 @@ func (s *BackupsScheduler) runPendingBackups() error {
return nil
}
func (s *BackupsScheduler) failBackupsInProgress() error {
backupsInProgress, err := s.backupRepository.FindByStatus(backups_core.BackupStatusInProgress)
if err != nil {
return err
}
for _, backup := range backupsInProgress {
if err := s.taskCancelManager.CancelTask(backup.ID); err != nil {
s.logger.Error(
"Failed to cancel backup via task cancel manager",
"backupId",
backup.ID,
"error",
err,
)
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(backup.DatabaseID)
if err != nil {
s.logger.Error("Failed to get backup config by database ID", "error", err)
continue
}
failMessage := "Backup failed due to application restart"
backup.FailMessage = &failMessage
backup.Status = backups_core.BackupStatusFailed
backup.BackupSizeMb = 0
s.backuperNode.SendBackupNotification(
backupConfig,
backup,
backups_config.NotificationBackupFailed,
&failMessage,
)
if err := s.backupRepository.Save(backup); err != nil {
return err
}
}
return nil
}
func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, fmt.Errorf("failed to get available nodes: %w", err)
}
@@ -402,7 +408,7 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return nil, fmt.Errorf("no nodes available")
}
stats, err := s.nodesRegistry.GetBackupNodesStats()
stats, err := s.backupNodesRegistry.GetBackupNodesStats()
if err != nil {
return nil, fmt.Errorf("failed to get backup nodes stats: %w", err)
}
@@ -415,14 +421,9 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
var bestNode *BackupNode
var bestScore float64 = -1
now := time.Now().UTC()
for i := range nodes {
node := &nodes[i]
if now.Sub(node.LastHeartbeat) > 2*time.Minute {
continue
}
activeBackups := statsMap[node.ID]
var score float64
@@ -445,16 +446,11 @@ func (s *BackupsScheduler) calculateLeastBusyNode() (*uuid.UUID, error) {
return &bestNode.ID, nil
}
func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUID) {
nodeID, err := uuid.Parse(nodeIDStr)
func (s *BackupsScheduler) onBackupCompleted(nodeID uuid.UUID, backupID uuid.UUID) {
// Verify this task is actually a backup (registry contains multiple task types)
_, err := s.backupRepository.FindByID(backupID)
if err != nil {
s.logger.Error(
"Failed to parse node ID from completion message",
"nodeId",
nodeIDStr,
"error",
err,
)
// Not a backup task, ignore it
return
}
@@ -498,7 +494,7 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
s.backupToNodeRelations[nodeID] = relation
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeIDStr); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress",
"nodeId",
@@ -512,18 +508,14 @@ func (s *BackupsScheduler) onBackupCompleted(nodeIDStr string, backupID uuid.UUI
}
func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
nodes, err := s.nodesRegistry.GetAvailableNodes()
nodes, err := s.backupNodesRegistry.GetAvailableNodes()
if err != nil {
return fmt.Errorf("failed to get available nodes: %w", err)
}
aliveNodeIDs := make(map[uuid.UUID]bool)
now := time.Now().UTC()
for _, node := range nodes {
if now.Sub(node.LastHeartbeat) <= 2*time.Minute {
aliveNodeIDs[node.ID] = true
}
aliveNodeIDs[node.ID] = true
}
for nodeID, relation := range s.backupToNodeRelations {
@@ -572,7 +564,7 @@ func (s *BackupsScheduler) checkDeadNodesAndFailBackups() error {
continue
}
if err := s.nodesRegistry.DecrementBackupsInProgress(nodeID.String()); err != nil {
if err := s.backupNodesRegistry.DecrementBackupsInProgress(nodeID); err != nil {
s.logger.Error(
"Failed to decrement backups in progress for dead node",
"nodeId",


@@ -42,8 +42,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -57,7 +57,8 @@ func Test_RunPendingBackups_WhenLastBackupWasYesterday_CreatesNewBackup(t *testi
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -111,8 +112,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -126,7 +127,8 @@ func Test_RunPendingBackups_WhenLastBackupWasRecentlyCompleted_SkipsBackup(t *te
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -179,8 +181,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -194,7 +196,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesDisabled_SkipsBackup(t
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = false
@@ -251,8 +254,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -266,7 +269,8 @@ func Test_RunPendingBackups_WhenLastBackupFailedAndRetriesEnabled_CreatesNewBack
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
@@ -324,8 +328,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -339,7 +343,8 @@ func Test_RunPendingBackups_WhenFailedBackupsExceedMaxRetries_SkipsBackup(t *tes
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
@@ -396,8 +401,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
@@ -410,7 +415,8 @@ func Test_RunPendingBackups_WhenBackupsDisabled_SkipsBackup(t *testing.T) {
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = false
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -449,6 +455,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
@@ -457,9 +465,15 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -471,7 +485,8 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -479,12 +494,12 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.NoError(t, err)
// Register mock node without subscribing to backups (simulates node crash after registration)
mockNodeID := uuid.New()
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Scheduler assigns backup to mock node
GetBackupsScheduler().StartBackup(database.ID, false)
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
@@ -493,7 +508,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Verify Valkey counter was incremented when backup was assigned
stats, err := nodesRegistry.GetBackupNodesStats()
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
foundStat := false
for _, stat := range stats {
@@ -523,7 +538,7 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
assert.Contains(t, *backups[0].FailMessage, "node unavailability")
// Verify Valkey counter was decremented after backup failed
stats, err = nodesRegistry.GetBackupNodesStats()
stats, err = backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
@@ -531,11 +546,110 @@ func Test_CheckDeadNodesAndFailBackups_WhenNodeDies_FailsBackupAndCleansUpRegist
}
}
// Node info should still exist in registry (not removed by checkDeadNodesAndFailBackups)
node, err := GetNodeFromRegistry(mockNodeID)
time.Sleep(200 * time.Millisecond)
}
func Test_OnBackupCompleted_WhenTaskIsNotBackup_SkipsProcessing(t *testing.T) {
cache_utils.ClearAllCache()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
var mockNodeID uuid.UUID
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
// Clean up mock node
if mockNodeID != uuid.Nil {
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: mockNodeID})
}
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
assert.NotNil(t, node)
assert.Equal(t, mockNodeID, node.ID)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Register mock node
mockNodeID = uuid.New()
err = CreateMockNodeInRegistry(mockNodeID, 100, time.Now().UTC())
assert.NoError(t, err)
// Start a backup and assign it to the node
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(100 * time.Millisecond)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
// Get initial state of the registry
initialStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range initialStats {
if stat.ID == mockNodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
assert.Equal(t, 1, initialActiveTasks, "Should have 1 active task")
// Call onBackupCompleted with a random UUID (not a backup ID)
nonBackupTaskID := uuid.New()
GetBackupsScheduler().onBackupCompleted(mockNodeID, nonBackupTaskID)
time.Sleep(100 * time.Millisecond)
// Verify: Active tasks counter should remain the same (not decremented)
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range stats {
if stat.ID == mockNodeID {
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active tasks should not change for non-backup task")
}
}
// Verify: backup should still be in progress (not modified)
backups, err = backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status,
"Backup status should not change for non-backup task completion")
// Verify: backupToNodeRelations should still contain the node
scheduler := GetBackupsScheduler()
_, exists := scheduler.backupToNodeRelations[mockNodeID]
assert.True(t, exists, "Node should still be in backupToNodeRelations")
time.Sleep(200 * time.Millisecond)
}
@@ -549,6 +663,14 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node3ID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node1ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node2ID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node3ID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node1ID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node2ID, 100, now)
@@ -557,17 +679,17 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
assert.NoError(t, err)
for range 5 {
err = nodesRegistry.IncrementBackupsInProgress(node1ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node1ID)
assert.NoError(t, err)
}
for range 2 {
err = nodesRegistry.IncrementBackupsInProgress(node2ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node2ID)
assert.NoError(t, err)
}
for range 8 {
err = nodesRegistry.IncrementBackupsInProgress(node3ID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node3ID)
assert.NoError(t, err)
}
@@ -584,17 +706,24 @@ func Test_CalculateLeastBusyNode_SelectsNodeWithBestScore(t *testing.T) {
node50MBsID := uuid.New()
now := time.Now().UTC()
defer func() {
// Clean up all mock nodes
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node100MBsID})
backupNodesRegistry.UnregisterNodeFromRegistry(BackupNode{ID: node50MBsID})
cache_utils.ClearAllCache()
}()
err := CreateMockNodeInRegistry(node100MBsID, 100, now)
assert.NoError(t, err)
err = CreateMockNodeInRegistry(node50MBsID, 50, now)
assert.NoError(t, err)
for range 10 {
err = nodesRegistry.IncrementBackupsInProgress(node100MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node100MBsID)
assert.NoError(t, err)
}
err = nodesRegistry.IncrementBackupsInProgress(node50MBsID.String())
err = backupNodesRegistry.IncrementBackupsInProgress(node50MBsID)
assert.NoError(t, err)
leastBusyNodeID, err := GetBackupsScheduler().calculateLeastBusyNode()
@@ -622,9 +751,11 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -636,7 +767,8 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.StorePeriod = period.PeriodWeek
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
@@ -707,3 +839,498 @@ func Test_FailBackupsInProgress_WhenSchedulerStarts_CancelsBackupsAndUpdatesStat
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupCompletes_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database, false)
// Wait for backup to complete
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and completed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusCompleted, backups[0].Status)
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup completion")
// Verify final active task count equals initial count
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup completion")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupFails_DecrementsActiveTaskCount(t *testing.T) {
cache_utils.ClearAllCache()
// Start scheduler so it can handle task completions
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Set wrong password to cause backup failure
// We need to bypass service layer validation which would fail on connection test
database.Postgresql.Password = "intentionally_wrong_password"
dbRepo := &databases.DatabaseRepository{}
_, err := dbRepo.Save(database)
assert.NoError(t, err)
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Get initial active task count
stats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
var initialActiveTasks int
for _, stat := range stats {
if stat.ID == backuperNode.nodeID {
initialActiveTasks = stat.ActiveBackups
break
}
}
t.Logf("Initial active tasks: %d", initialActiveTasks)
// Start backup
scheduler.StartBackup(database, false)
// Wait for backup to fail
WaitForBackupCompletion(t, database.ID, 0, 10*time.Second)
// Verify backup was created and failed
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusFailed, backups[0].Status)
assert.NotNil(t, backups[0].FailMessage)
if backups[0].FailMessage != nil {
t.Logf("Backup failed with message: %s", *backups[0].FailMessage)
}
// Wait for active task count to decrease
decreased := WaitForActiveTasksDecrease(
t,
backuperNode.nodeID,
initialActiveTasks+1,
10*time.Second,
)
assert.True(t, decreased, "Active task count should have decreased after backup failure")
// Verify final active task count equals initial count
finalStats, err := backupNodesRegistry.GetBackupNodesStats()
assert.NoError(t, err)
for _, stat := range finalStats {
if stat.ID == backuperNode.nodeID {
t.Logf("Final active tasks: %d", stat.ActiveBackups)
assert.Equal(t, initialActiveTasks, stat.ActiveBackups,
"Active task count should return to initial value after backup failure")
break
}
}
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_WhenBackupAlreadyInProgress_SkipsNewBackup(t *testing.T) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create an in-progress backup manually
inProgressBackup := &backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusInProgress,
BackupSizeMb: 0,
CreatedAt: time.Now().UTC(),
}
err = backupRepository.Save(inProgressBackup)
assert.NoError(t, err)
// Try to start a new backup - should be skipped
GetBackupsScheduler().StartBackup(database, false)
time.Sleep(200 * time.Millisecond)
// Verify only 1 backup exists (the original in-progress one)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1)
assert.Equal(t, backups_core.BackupStatusInProgress, backups[0].Status)
assert.Equal(t, inProgressBackup.ID, backups[0].ID)
time.Sleep(200 * time.Millisecond)
}
func Test_RunPendingBackups_WhenLastBackupFailedWithIsSkipRetry_SkipsBackupEvenWithRetriesEnabled(
t *testing.T,
) {
cache_utils.ClearAllCache()
backuperNode := CreateTestBackuperNode()
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups with retries enabled and high retry count
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig.IsBackupsEnabled = true
backupConfig.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig.RetentionTimePeriod = period.PeriodWeek
backupConfig.Storage = storage
backupConfig.StorageID = &storage.ID
backupConfig.IsRetryIfFailed = true
backupConfig.MaxFailedTriesCount = 5
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig)
assert.NoError(t, err)
// Create a failed backup with IsSkipRetry set to true
failMessage := "backup failed due to size limit exceeded"
backupRepository.Save(&backups_core.Backup{
DatabaseID: database.ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusFailed,
FailMessage: &failMessage,
IsSkipRetry: true,
CreatedAt: time.Now().UTC().Add(-1 * time.Hour),
})
// Verify GetRemainedBackupTryCount returns 0 even though retries are enabled
lastBackup, err := backupRepository.FindLastByDatabaseID(database.ID)
assert.NoError(t, err)
assert.NotNil(t, lastBackup)
remainedTries := GetBackupsScheduler().GetRemainedBackupTryCount(lastBackup)
assert.Equal(t, 0, remainedTries, "Should return 0 tries when IsSkipRetry is true")
// Run the scheduler
GetBackupsScheduler().runPendingBackups()
time.Sleep(100 * time.Millisecond)
// Verify no new backup was created (still only 1 backup exists)
backups, err := backupRepository.FindByDatabaseID(database.ID)
assert.NoError(t, err)
assert.Len(t, backups, 1, "No retry should be attempted when IsSkipRetry is true")
time.Sleep(200 * time.Millisecond)
}
func Test_StartBackup_When2BackupsStartedForDifferentDatabases_BothUseCasesAreCalled(t *testing.T) {
cache_utils.ClearAllCache()
// Create mock tracking use case
mockUseCase := NewMockTrackingBackupUsecase()
// Create BackuperNode with mock use case
backuperNode := CreateTestBackuperNodeWithUseCase(mockUseCase)
cancel := StartBackuperNodeForTest(t, backuperNode)
defer StopBackuperNodeForTest(t, cancel, backuperNode)
// Create scheduler
scheduler := CreateTestScheduler()
schedulerCancel := StartSchedulerForTest(t, scheduler)
defer schedulerCancel()
// Setup test data
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
// Create 2 separate databases
database1 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
database2 := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
// Cleanup backups for database1
backups1, _ := backupRepository.FindByDatabaseID(database1.ID)
for _, backup := range backups1 {
backupRepository.DeleteByID(backup.ID)
}
// Cleanup backups for database2
backups2, _ := backupRepository.FindByDatabaseID(database2.ID)
for _, backup := range backups2 {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Enable backups for database1
backupConfig1, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database1.ID)
assert.NoError(t, err)
timeOfDay := "04:00"
backupConfig1.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig1.IsBackupsEnabled = true
backupConfig1.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig1.RetentionTimePeriod = period.PeriodWeek
backupConfig1.Storage = storage
backupConfig1.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig1)
assert.NoError(t, err)
// Enable backups for database2
backupConfig2, err := backups_config.GetBackupConfigService().
GetBackupConfigByDbId(database2.ID)
assert.NoError(t, err)
backupConfig2.BackupInterval = &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
}
backupConfig2.IsBackupsEnabled = true
backupConfig2.RetentionPolicyType = backups_config.RetentionPolicyTypeTimePeriod
backupConfig2.RetentionTimePeriod = period.PeriodWeek
backupConfig2.Storage = storage
backupConfig2.StorageID = &storage.ID
_, err = backups_config.GetBackupConfigService().SaveBackupConfig(backupConfig2)
assert.NoError(t, err)
// Start 2 backups simultaneously
t.Log("Starting backup for database1")
scheduler.StartBackup(database1, false)
t.Log("Starting backup for database2")
scheduler.StartBackup(database2, false)
// Wait up to 10 seconds for both backups to complete
t.Log("Waiting for both backups to complete...")
success := assert.Eventually(t, func() bool {
callCount := mockUseCase.GetCallCount()
t.Logf("Current call count: %d/2", callCount)
return callCount == 2
}, 10*time.Second, 200*time.Millisecond, "Both use cases should be called within 10 seconds")
if !success {
t.Logf("Test failed: Only %d out of 2 use cases were called", mockUseCase.GetCallCount())
}
// Verify both backup IDs were received
calledBackupIDs := mockUseCase.GetCalledBackupIDs()
t.Logf("Called backup IDs: %v", calledBackupIDs)
assert.Len(t, calledBackupIDs, 2, "Both backup IDs should be tracked")
// Verify both backups exist in repository and are completed
backups1, err := backupRepository.FindByDatabaseID(database1.ID)
assert.NoError(t, err)
assert.Len(t, backups1, 1, "Database1 should have 1 backup")
if len(backups1) > 0 {
t.Logf("Database1 backup status: %s", backups1[0].Status)
}
backups2, err := backupRepository.FindByDatabaseID(database2.ID)
assert.NoError(t, err)
assert.Len(t, backups2, 1, "Database2 should have 1 backup")
if len(backups2) > 0 {
t.Logf("Database2 backup status: %s", backups2[0].Status)
}
// Verify both backups completed successfully
if len(backups1) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups1[0].Status,
"Database1 backup should be completed")
}
if len(backups2) > 0 {
assert.Equal(t, backups_core.BackupStatusCompleted, backups2[0].Status,
"Database2 backup should be completed")
}
time.Sleep(200 * time.Millisecond)
}


@@ -3,6 +3,8 @@ package backuping
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
@@ -35,19 +37,56 @@ func CreateTestRouter() *gin.Engine {
func CreateTestBackuperNode() *BackuperNode {
return &BackuperNode{
databases.GetDatabaseService(),
encryption.GetFieldEncryptor(),
workspaces_services.GetWorkspaceService(),
backupRepository,
backups_config.GetBackupConfigService(),
storages.GetStorageService(),
notifiers.GetNotifierService(),
backupCancelManager,
nodesRegistry,
logger.GetLogger(),
usecases.GetCreateBackupUsecase(),
uuid.New(),
time.Time{},
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestBackuperNodeWithUseCase(useCase backups_core.CreateBackupUsecase) *BackuperNode {
return &BackuperNode{
databaseService: databases.GetDatabaseService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
workspaceService: workspaces_services.GetWorkspaceService(),
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
storageService: storages.GetStorageService(),
notificationSender: notifiers.GetNotifierService(),
backupCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
logger: logger.GetLogger(),
createBackupUseCase: useCase,
nodeID: uuid.New(),
lastHeartbeat: time.Time{},
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func CreateTestScheduler() *BackupsScheduler {
return &BackupsScheduler{
backupRepository: backupRepository,
backupConfigService: backups_config.GetBackupConfigService(),
taskCancelManager: taskCancelManager,
backupNodesRegistry: backupNodesRegistry,
lastBackupTime: time.Now().UTC(),
logger: logger.GetLogger(),
backupToNodeRelations: make(map[uuid.UUID]BackupToNodeRelation),
backuperNode: CreateTestBackuperNode(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
@@ -113,7 +152,7 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
// Poll registry for node presence instead of fixed sleep
deadline := time.Now().UTC().Add(5 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
for _, node := range nodes {
if node.ID == backuperNode.nodeID {
@@ -138,6 +177,34 @@ func StartBackuperNodeForTest(t *testing.T, backuperNode *BackuperNode) context.
return nil
}
// StartSchedulerForTest starts the BackupsScheduler in a goroutine for testing.
// The scheduler subscribes to task completions and manages backup lifecycle.
// Returns a context cancel function that should be deferred to stop the scheduler.
func StartSchedulerForTest(t *testing.T, scheduler *BackupsScheduler) context.CancelFunc {
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
scheduler.Run(ctx)
close(done)
}()
// Give scheduler time to subscribe to completions
time.Sleep(100 * time.Millisecond)
t.Log("BackupsScheduler started")
return func() {
cancel()
select {
case <-done:
t.Log("BackupsScheduler stopped gracefully")
case <-time.After(2 * time.Second):
t.Log("BackupsScheduler stop timeout")
}
}
}
// StopBackuperNodeForTest stops the BackuperNode by canceling its context.
// It waits for the node to unregister from the registry.
func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNode *BackuperNode) {
@@ -146,7 +213,7 @@ func StopBackuperNodeForTest(t *testing.T, cancel context.CancelFunc, backuperNo
// Wait for node to unregister from registry
deadline := time.Now().UTC().Add(2 * time.Second)
for time.Now().UTC().Before(deadline) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err == nil {
found := false
for _, node := range nodes {
@@ -173,7 +240,7 @@ func CreateMockNodeInRegistry(nodeID uuid.UUID, throughputMBs int, lastHeartbeat
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func UpdateNodeHeartbeatDirectly(
@@ -187,11 +254,11 @@ func UpdateNodeHeartbeatDirectly(
LastHeartbeat: lastHeartbeat,
}
return nodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
return backupNodesRegistry.HearthbeatNodeInRegistry(lastHeartbeat, backupNode)
}
func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
nodes, err := nodesRegistry.GetAvailableNodes()
nodes, err := backupNodesRegistry.GetAvailableNodes()
if err != nil {
return nil, err
}
@@ -204,3 +271,48 @@ func GetNodeFromRegistry(nodeID uuid.UUID) (*BackupNode, error) {
return nil, fmt.Errorf("node not found")
}
// WaitForActiveTasksDecrease waits for the active task count to decrease below the initial count.
// It polls the registry every 500ms until the count decreases or the timeout is reached.
// Returns true if the count decreased, false if timeout was reached.
func WaitForActiveTasksDecrease(
t *testing.T,
nodeID uuid.UUID,
initialCount int,
timeout time.Duration,
) bool {
deadline := time.Now().UTC().Add(timeout)
for time.Now().UTC().Before(deadline) {
stats, err := backupNodesRegistry.GetBackupNodesStats()
if err != nil {
t.Logf("WaitForActiveTasksDecrease: error getting node stats: %v", err)
time.Sleep(500 * time.Millisecond)
continue
}
for _, stat := range stats {
if stat.ID == nodeID {
t.Logf(
"WaitForActiveTasksDecrease: current active tasks = %d (initial = %d)",
stat.ActiveBackups,
initialCount,
)
if stat.ActiveBackups < initialCount {
t.Logf(
"WaitForActiveTasksDecrease: active tasks decreased from %d to %d",
initialCount,
stat.ActiveBackups,
)
return true
}
break
}
}
time.Sleep(500 * time.Millisecond)
}
t.Logf("WaitForActiveTasksDecrease: timeout waiting for active tasks to decrease")
return false
}


@@ -1,75 +0,0 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"log/slog"
"sync"
"github.com/google/uuid"
)
const backupCancelChannel = "backup:cancel"
type BackupCancelManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
pubsub *cache_utils.PubSubManager
logger *slog.Logger
}
func (m *BackupCancelManager) StartSubscription() {
ctx := context.Background()
handler := func(message string) {
backupID, err := uuid.Parse(message)
if err != nil {
m.logger.Error("Invalid backup ID in cancel message", "message", message, "error", err)
return
}
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if exists {
cancelFunc()
delete(m.cancelFuncs, backupID)
m.logger.Info("Cancelled backup via Pub/Sub", "backupID", backupID)
}
}
err := m.pubsub.Subscribe(ctx, backupCancelChannel, handler)
if err != nil {
m.logger.Error("Failed to subscribe to backup cancel channel", "error", err)
} else {
m.logger.Info("Successfully subscribed to backup cancel channel")
}
}
func (m *BackupCancelManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
m.logger.Debug("Registered backup", "backupID", backupID)
}
func (m *BackupCancelManager) CancelBackup(backupID uuid.UUID) error {
ctx := context.Background()
err := m.pubsub.Publish(ctx, backupCancelChannel, backupID.String())
if err != nil {
m.logger.Error("Failed to publish cancel message", "backupID", backupID, "error", err)
return err
}
m.logger.Info("Published backup cancel message", "backupID", backupID)
return nil
}
func (m *BackupCancelManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
m.logger.Debug("Unregistered backup", "backupID", backupID)
}


@@ -1,25 +0,0 @@
package backups_cancellation
import (
"context"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
"sync"
"github.com/google/uuid"
)
var backupCancelManager = &BackupCancelManager{
sync.RWMutex{},
make(map[uuid.UUID]context.CancelFunc),
cache_utils.NewPubSubManager(),
logger.GetLogger(),
}
func GetBackupCancelManager() *BackupCancelManager {
return backupCancelManager
}
func SetupDependencies() {
backupCancelManager.StartSubscription()
}


@@ -1,17 +1,37 @@
package common
import backups_config "databasus-backend/internal/features/backups/config"
import (
backups_config "databasus-backend/internal/features/backups/config"
"errors"
type BackupType string
const (
BackupTypeDefault BackupType = "DEFAULT" // For MySQL, MongoDB, PostgreSQL legacy (-Fc)
BackupTypeDirectory BackupType = "DIRECTORY" // PostgreSQL directory type (-Fd)
"github.com/google/uuid"
)
type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
Type BackupType
BackupID uuid.UUID `json:"backupId"`
EncryptionSalt *string `json:"encryptionSalt"`
EncryptionIV *string `json:"encryptionIV"`
Encryption backups_config.BackupEncryption `json:"encryption"`
}
func (m *BackupMetadata) Validate() error {
if m.BackupID == uuid.Nil {
return errors.New("backup ID is required")
}
if m.Encryption == "" {
return errors.New("encryption is required")
}
if m.Encryption == backups_config.BackupEncryptionEncrypted {
if m.EncryptionSalt == nil {
return errors.New("encryption salt is required when encryption is enabled")
}
if m.EncryptionIV == nil {
return errors.New("encryption IV is required when encryption is enabled")
}
}
return nil
}
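Since BackupMetadata now carries JSON tags and a Validate method, it can be written as a sidecar JSON file next to the backup and checked again on restore. A minimal usage sketch follows; the local struct mirrors the shape above with uuid.UUID replaced by a string, and "ENCRYPTED" is only a guess at the value of backups_config.BackupEncryptionEncrypted, both assumptions made to keep the snippet dependency-free:
package main
import (
	"encoding/json"
	"errors"
	"fmt"
)
// Local stand-in for the BackupMetadata type above; "ENCRYPTED" is a
// hypothetical value for backups_config.BackupEncryptionEncrypted.
type backupMetadata struct {
	BackupID       string  `json:"backupId"`
	EncryptionSalt *string `json:"encryptionSalt"`
	EncryptionIV   *string `json:"encryptionIV"`
	Encryption     string  `json:"encryption"`
}
func (m *backupMetadata) validate() error {
	if m.BackupID == "" {
		return errors.New("backup ID is required")
	}
	if m.Encryption == "" {
		return errors.New("encryption is required")
	}
	if m.Encryption == "ENCRYPTED" {
		if m.EncryptionSalt == nil {
			return errors.New("encryption salt is required when encryption is enabled")
		}
		if m.EncryptionIV == nil {
			return errors.New("encryption IV is required when encryption is enabled")
		}
	}
	return nil
}
func main() {
	meta := backupMetadata{BackupID: "0f8fad5b-d9cb-469f-a165-70867728950e", Encryption: "ENCRYPTED"}
	fmt.Println(meta.validate()) // error: salt and IV required once encryption is on
	salt, iv := "c2FsdA==", "aXY="
	meta.EncryptionSalt, meta.EncryptionIV = &salt, &iv
	fmt.Println(meta.validate()) // <nil>
	out, _ := json.MarshalIndent(&meta, "", "  ")
	fmt.Println(string(out)) // the sidecar content stored alongside the backup file
}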


@@ -1,12 +1,16 @@
package backups
import (
"context"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/databases"
users_middleware "databasus-backend/internal/features/users/middleware"
files_utils "databasus-backend/internal/util/files"
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -174,6 +178,7 @@ func (c *BackupController) CancelBackup(ctx *gin.Context) {
// @Success 200 {object} backups_download.GenerateDownloadTokenResponse
// @Failure 400
// @Failure 401
// @Failure 409 {object} map[string]string "Download already in progress"
// @Router /backups/{id}/download-token [post]
func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
@@ -190,6 +195,15 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
response, err := c.backupService.GenerateDownloadToken(user, id)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "Download already in progress for some of backups. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -199,14 +213,22 @@ func (c *BackupController) GenerateDownloadToken(ctx *gin.Context) {
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup using a download token
// @Description Download the backup file for the specified backup using a download token.
// @Description
// @Description **Download Concurrency Control:**
// @Description - Only one download per user is allowed at a time
// @Description - If a download is already in progress, returns 409 Conflict
// @Description - Downloads are tracked in the cache with a 5-second TTL and a 3-second heartbeat
// @Description - Browser cancellations automatically release the download lock
// @Description - Server crashes are handled via automatic cache expiry (5 seconds)
// @Tags backups
// @Param id path string true "Backup ID"
// @Param token query string true "Download token"
// @Success 200 {file} file
// @Failure 400
// @Failure 401
// @Failure 500
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 409 {object} map[string]string "Download already in progress"
// @Failure 500 {object} map[string]string
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
token := ctx.Query("token")
@@ -215,7 +237,6 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
// Get backup ID from URL
backupIDParam := ctx.Param("id")
backupID, err := uuid.Parse(backupIDParam)
if err != nil {
@@ -223,13 +244,22 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
return
}
downloadToken, err := c.backupService.ValidateDownloadToken(token)
downloadToken, rateLimiter, err := c.backupService.ValidateDownloadToken(token)
if err != nil {
if err == backups_download.ErrDownloadAlreadyInProgress {
ctx.JSON(
http.StatusConflict,
gin.H{
"error": "download already in progress for this user. Please wait until previous download completed or cancel it",
},
)
return
}
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
}
// Verify token is for the requested backup
if downloadToken.BackupID != backupID {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired download token"})
return
@@ -239,18 +269,28 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
downloadToken.BackupID,
)
if err != nil {
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
rateLimitedReader := backups_download.NewRateLimitedReader(fileReader, rateLimiter)
heartbeatCtx, cancelHeartbeat := context.WithCancel(context.Background())
defer func() {
if err := fileReader.Close(); err != nil {
cancelHeartbeat()
c.backupService.UnregisterDownload(downloadToken.UserID)
c.backupService.ReleaseDownloadLock(downloadToken.UserID)
if err := rateLimitedReader.Close(); err != nil {
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
go c.startDownloadHeartbeat(heartbeatCtx, downloadToken.UserID)
filename := c.generateBackupFilename(backup, database)
// Set Content-Length for progress tracking
if backup.BackupSizeMb > 0 {
sizeBytes := int64(backup.BackupSizeMb * 1024 * 1024)
ctx.Header("Content-Length", fmt.Sprintf("%d", sizeBytes))
@@ -262,13 +302,11 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
fmt.Sprintf("attachment; filename=\"%s\"", filename),
)
_, err = io.Copy(ctx.Writer, fileReader)
_, err = io.Copy(ctx.Writer, rateLimitedReader)
if err != nil {
fmt.Printf("Error streaming file: %v\n", err)
return
}
// Write audit log after successful download
c.backupService.WriteAuditLogForDownload(downloadToken.UserID, backup, database)
}
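GetFile now streams through the rate-limited reader rather than the raw storage reader. The NewRateLimitedReader implementation itself is not part of this diff; the sketch below shows one common way such a wrapper works, assuming a token-bucket limiter from golang.org/x/time/rate (the rateLimitedReader type and its fields are illustrative, not the project's actual code):
package main
import (
	"context"
	"fmt"
	"io"
	"strings"
	"golang.org/x/time/rate"
)
// rateLimitedReader charges the limiter one token per byte read, so io.Copy
// through it is throttled. The limiter burst must be at least the largest
// read buffer, otherwise WaitN returns an error.
type rateLimitedReader struct {
	r       io.Reader
	limiter *rate.Limiter
}
func (rr *rateLimitedReader) Read(p []byte) (int, error) {
	n, err := rr.r.Read(p)
	if n > 0 {
		if werr := rr.limiter.WaitN(context.Background(), n); werr != nil {
			return n, werr
		}
	}
	return n, err
}
func main() {
	src := strings.NewReader("pretend this is a large backup stream")
	rr := &rateLimitedReader{r: src, limiter: rate.NewLimiter(1024, 32*1024)} // 1 KiB/s, 32 KiB burst
	n, _ := io.Copy(io.Discard, rr) // same call shape as the controller's io.Copy
	fmt.Println("streamed bytes:", n)
}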
@@ -284,7 +322,7 @@ func (c *BackupController) generateBackupFilename(
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
// Sanitize database name for filename (replace spaces and special chars)
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
// Determine extension based on database type
extension := c.getBackupExtension(database.Type)
@@ -308,29 +346,16 @@ func (c *BackupController) getBackupExtension(
}
}
func sanitizeFilename(name string) string {
// Replace characters that are invalid in filenames
replacer := map[rune]rune{
' ': '_',
'/': '-',
'\\': '-',
':': '-',
'*': '-',
'?': '-',
'"': '-',
'<': '-',
'>': '-',
'|': '-',
}
func (c *BackupController) startDownloadHeartbeat(ctx context.Context, userID uuid.UUID) {
ticker := time.NewTicker(backups_download.GetDownloadHeartbeatInterval())
defer ticker.Stop()
result := make([]rune, 0, len(name))
for _, char := range name {
if replacement, exists := replacer[char]; exists {
result = append(result, replacement)
} else {
result = append(result, char)
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
c.backupService.RefreshDownloadLock(userID)
}
}
return string(result)
}
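startDownloadHeartbeat above is one half of the concurrency control documented on GetFile: the lock entry lives in the cache with a 5-second TTL, and the heartbeat refreshes it every 3 seconds while bytes are flowing, so a crashed server simply stops refreshing and the lock expires on its own. A self-contained sketch of that acquire/refresh/release pattern, with an in-memory map standing in for the real cache (all names here are hypothetical, not the project's API):
package main
import (
	"fmt"
	"sync"
	"time"
)
type downloadLock struct {
	mu      sync.Mutex
	expires map[string]time.Time // userID -> lock expiry
	ttl     time.Duration
}
func newDownloadLock(ttl time.Duration) *downloadLock {
	return &downloadLock{expires: make(map[string]time.Time), ttl: ttl}
}
// TryAcquire succeeds only if the user holds no unexpired lock; on failure
// the caller answers 409 Conflict.
func (l *downloadLock) TryAcquire(userID string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()
	if exp, ok := l.expires[userID]; ok && time.Now().Before(exp) {
		return false
	}
	l.expires[userID] = time.Now().Add(l.ttl)
	return true
}
// Refresh extends the lock; the heartbeat goroutine calls this on each tick.
func (l *downloadLock) Refresh(userID string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	if _, ok := l.expires[userID]; ok {
		l.expires[userID] = time.Now().Add(l.ttl)
	}
}
// Release drops the lock explicitly; a dead server just stops refreshing and
// the entry times out after the TTL.
func (l *downloadLock) Release(userID string) {
	l.mu.Lock()
	defer l.mu.Unlock()
	delete(l.expires, userID)
}
func main() {
	lock := newDownloadLock(5 * time.Second)
	fmt.Println(lock.TryAcquire("user-1")) // true: download starts
	fmt.Println(lock.TryAcquire("user-1")) // false: second download rejected
	lock.Release("user-1")
	fmt.Println(lock.TryAcquire("user-1")) // true: lock is free again
}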

File diff suppressed because it is too large


@@ -8,8 +8,6 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type NotificationSender interface {
@@ -23,7 +21,7 @@ type NotificationSender interface {
type CreateBackupUsecase interface {
Execute(
ctx context.Context,
backupID uuid.UUID,
backup *Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,

View File

@@ -8,13 +8,15 @@ import (
)
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
FileName string `json:"fileName" gorm:"column:file_name;type:text;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
IsSkipRetry bool `json:"isSkipRetry" gorm:"column:is_skip_retry;type:boolean;not null"`
BackupSizeMb float64 `json:"backupSizeMb" gorm:"column:backup_size_mb;default:0"`


@@ -212,3 +212,36 @@ func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error
return count, nil
}
func (r *BackupRepository) GetTotalSizeByDatabase(databaseID uuid.UUID) (float64, error) {
var totalSize float64
if err := storage.
GetDb().
Model(&Backup{}).
Select("COALESCE(SUM(backup_size_mb), 0)").
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Scan(&totalSize).Error; err != nil {
return 0, err
}
return totalSize, nil
}
func (r *BackupRepository) FindOldestByDatabaseExcludingInProgress(
databaseID uuid.UUID,
limit int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ? AND status != ?", databaseID, BackupStatusInProgress).
Order("created_at ASC").
Limit(limit).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
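These two queries support size-based retention: sum what a database already stores, then fetch its oldest completed backups so they can be trimmed first. The real cleaner lives in the backuping package and is not shown in this diff; a standalone sketch of the intended loop with in-memory stand-ins:

package main

import "fmt"

type backup struct {
	name   string
	sizeMb float64
}

// enforceTotalSizeCap mirrors the intended use of the two queries above:
// sum the sizes, then drop oldest-first until the total fits under the cap.
func enforceTotalSizeCap(oldestFirst []backup, capMb float64) []backup {
	total := 0.0
	for _, b := range oldestFirst {
		total += b.sizeMb
	}
	for len(oldestFirst) > 0 && total > capMb {
		total -= oldestFirst[0].sizeMb
		oldestFirst = oldestFirst[1:]
	}
	return oldestFirst
}

func main() {
	kept := enforceTotalSizeCap([]backup{{"mon", 400}, {"tue", 400}, {"wed", 400}}, 1000)
	fmt.Println(kept) // [{tue 400} {wed 400}]
}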


@@ -1,9 +1,11 @@
package backups
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/usecases"
@@ -12,6 +14,7 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/logger"
@@ -19,25 +22,26 @@ import (
var backupRepository = &backups_core.BackupRepository{}
var backupCancelManager = backups_cancellation.GetBackupCancelManager()
var taskCancelManager = task_cancellation.GetTaskCancelManager()
var backupService = &BackupService{
databaseService: databases.GetDatabaseService(),
storageService: storages.GetStorageService(),
backupRepository: backupRepository,
notifierService: notifiers.GetNotifierService(),
notificationSender: notifiers.GetNotifierService(),
backupConfigService: backups_config.GetBackupConfigService(),
secretKeyService: encryption_secrets.GetSecretKeyService(),
fieldEncryptor: encryption.GetFieldEncryptor(),
createBackupUseCase: usecases.GetCreateBackupUsecase(),
logger: logger.GetLogger(),
backupRemoveListeners: []backups_core.BackupRemoveListener{},
workspaceService: workspaces_services.GetWorkspaceService(),
auditLogService: audit_logs.GetAuditLogService(),
backupCancelManager: backupCancelManager,
downloadTokenService: backups_download.GetDownloadTokenService(),
backupSchedulerService: backuping.GetBackupsScheduler(),
databases.GetDatabaseService(),
storages.GetStorageService(),
backupRepository,
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]backups_core.BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
taskCancelManager,
backups_download.GetDownloadTokenService(),
backuping.GetBackupsScheduler(),
backuping.GetBackupCleaner(),
}
var backupController = &BackupController{
@@ -52,11 +56,26 @@ func GetBackupController() *BackupController {
return backupController
}
func SetupDependencies() {
	backups_config.
		GetBackupConfigService().
		SetDatabaseStorageChangeListener(backupService)
	databases.GetDatabaseService().AddDbRemoveListener(backupService)
	databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
}
var (
	setupOnce sync.Once
	isSetup   atomic.Bool
)
func SetupDependencies() {
	wasAlreadySetup := isSetup.Load()
	setupOnce.Do(func() {
		backups_config.
			GetBackupConfigService().
			SetDatabaseStorageChangeListener(backupService)
		databases.GetDatabaseService().AddDbRemoveListener(backupService)
		databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
		isSetup.Store(true)
	})
	if wasAlreadySetup {
		logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
	}
}
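The guard above combines sync.Once with an atomic.Bool so a second call is a no-op that only warns. A self-contained sketch of the same idiom (fmt stands in for the project logger); the DownloadTokenBackgroundService later in this diff applies the same shape but panics instead of warning:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

var (
	setupOnce sync.Once
	isSetup   atomic.Bool
)

func setup() {
	// Read the flag before Do: if it is already true, an earlier call has
	// fully completed the setup and this call should only emit a warning.
	wasAlreadySetup := isSetup.Load()
	setupOnce.Do(func() {
		fmt.Println("wiring dependencies")
		isSetup.Store(true)
	})
	if wasAlreadySetup {
		fmt.Println("warn: setup called multiple times, ignoring subsequent call")
	}
}

func main() {
	setup() // wiring dependencies
	setup() // warn: setup called multiple times, ignoring subsequent call
}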


@@ -2,33 +2,49 @@ package backups_download
import (
"context"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
)
type DownloadTokenBackgroundService struct {
downloadTokenService *DownloadTokenService
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *DownloadTokenBackgroundService) Run(ctx context.Context) {
	s.logger.Info("Starting download token cleanup background service")
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
				s.logger.Error("Failed to clean expired download tokens", "error", err)
			}
		}
	}
	wasAlreadyRun := s.hasRun.Load()
	if ctx.Err() != nil {
		return
	}
	s.runOnce.Do(func() {
		s.hasRun.Store(true)
		s.logger.Info("Starting download token cleanup background service")
		if ctx.Err() != nil {
			return
		}
		ticker := time.NewTicker(1 * time.Minute)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				if err := s.downloadTokenService.CleanExpiredTokens(); err != nil {
					s.logger.Error("Failed to clean expired download tokens", "error", err)
				}
			}
		}
	})
	if wasAlreadyRun {
		panic(fmt.Sprintf("%T.Run() called multiple times", s))
	}
}


@@ -0,0 +1,81 @@
package backups_download
import (
"fmt"
"sync"
"github.com/google/uuid"
)
type BandwidthManager struct {
mu sync.RWMutex
activeDownloads map[uuid.UUID]*activeDownload
maxTotalBytesPerSecond int64
bytesPerSecondPerDownload int64
}
type activeDownload struct {
userID uuid.UUID
rateLimiter *RateLimiter
}
func NewBandwidthManager(throughputMBs int) *BandwidthManager {
// Use 75% of total throughput
maxBytes := int64(throughputMBs) * 1024 * 1024 * 75 / 100
return &BandwidthManager{
activeDownloads: make(map[uuid.UUID]*activeDownload),
maxTotalBytesPerSecond: maxBytes,
bytesPerSecondPerDownload: maxBytes,
}
}
func (bm *BandwidthManager) RegisterDownload(userID uuid.UUID) (*RateLimiter, error) {
bm.mu.Lock()
defer bm.mu.Unlock()
if _, exists := bm.activeDownloads[userID]; exists {
return nil, fmt.Errorf("download already registered for user %s", userID)
}
rateLimiter := NewRateLimiter(bm.bytesPerSecondPerDownload)
bm.activeDownloads[userID] = &activeDownload{
userID: userID,
rateLimiter: rateLimiter,
}
bm.recalculateRates()
return rateLimiter, nil
}
func (bm *BandwidthManager) UnregisterDownload(userID uuid.UUID) {
bm.mu.Lock()
defer bm.mu.Unlock()
delete(bm.activeDownloads, userID)
bm.recalculateRates()
}
func (bm *BandwidthManager) GetActiveDownloadCount() int {
bm.mu.RLock()
defer bm.mu.RUnlock()
return len(bm.activeDownloads)
}
func (bm *BandwidthManager) recalculateRates() {
activeCount := len(bm.activeDownloads)
if activeCount == 0 {
bm.bytesPerSecondPerDownload = bm.maxTotalBytesPerSecond
return
}
newRate := bm.maxTotalBytesPerSecond / int64(activeCount)
bm.bytesPerSecondPerDownload = newRate
for _, download := range bm.activeDownloads {
download.rateLimiter.UpdateRate(newRate)
}
}
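The manager budgets 75% of the configured throughput and splits it evenly, so registering or unregistering a download changes every active limiter's rate. A standalone sketch of just the arithmetic (perDownloadRate is a stand-in for recalculateRates, using the 100 MB/s figure from the tests):

package main

import "fmt"

// perDownloadRate reproduces the math above: 75% of the node throughput,
// divided evenly across active downloads (full budget when none are active).
func perDownloadRate(throughputMBs, activeDownloads int) int64 {
	maxTotal := int64(throughputMBs) * 1024 * 1024 * 75 / 100
	if activeDownloads == 0 {
		return maxTotal
	}
	return maxTotal / int64(activeDownloads)
}

func main() {
	for n := 0; n <= 3; n++ {
		fmt.Printf("%d active -> %d bytes/s each\n", n, perDownloadRate(100, n))
	}
	// 0 active -> 78643200 bytes/s each
	// 1 active -> 78643200 bytes/s each
	// 2 active -> 39321600 bytes/s each
	// 3 active -> 26214400 bytes/s each
}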


@@ -0,0 +1,150 @@
package backups_download
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_BandwidthManager_RegisterSingleDownload(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
expectedBytesPerSec := int64(100 * 1024 * 1024 * 75 / 100)
assert.Equal(t, expectedBytesPerSec, manager.maxTotalBytesPerSecond)
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
userID := uuid.New()
rateLimiter, err := manager.RegisterDownload(userID)
assert.NoError(t, err)
assert.NotNil(t, rateLimiter)
assert.Equal(t, 1, manager.GetActiveDownloadCount())
assert.Equal(t, expectedBytesPerSec, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedBytesPerSec, rateLimiter.bytesPerSecond)
}
func Test_BandwidthManager_RegisterMultipleDownloads_BandwidthShared(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
user1 := uuid.New()
rateLimiter1, err := manager.RegisterDownload(user1)
assert.NoError(t, err)
assert.Equal(t, maxBytes, rateLimiter1.bytesPerSecond)
user2 := uuid.New()
rateLimiter2, err := manager.RegisterDownload(user2)
assert.NoError(t, err)
expectedPerDownload := maxBytes / 2
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
user3 := uuid.New()
rateLimiter3, err := manager.RegisterDownload(user3)
assert.NoError(t, err)
expectedPerDownload = maxBytes / 3
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter2.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
assert.Equal(t, 3, manager.GetActiveDownloadCount())
}
func Test_BandwidthManager_UnregisterDownload_BandwidthRebalanced(t *testing.T) {
throughputMBs := 100
manager := NewBandwidthManager(throughputMBs)
maxBytes := int64(100 * 1024 * 1024 * 75 / 100)
user1 := uuid.New()
rateLimiter1, _ := manager.RegisterDownload(user1)
user2 := uuid.New()
_, _ = manager.RegisterDownload(user2)
user3 := uuid.New()
rateLimiter3, _ := manager.RegisterDownload(user3)
assert.Equal(t, 3, manager.GetActiveDownloadCount())
expectedPerDownload := maxBytes / 3
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
manager.UnregisterDownload(user2)
assert.Equal(t, 2, manager.GetActiveDownloadCount())
expectedPerDownload = maxBytes / 2
assert.Equal(t, expectedPerDownload, manager.bytesPerSecondPerDownload)
assert.Equal(t, expectedPerDownload, rateLimiter1.bytesPerSecond)
assert.Equal(t, expectedPerDownload, rateLimiter3.bytesPerSecond)
manager.UnregisterDownload(user1)
assert.Equal(t, 1, manager.GetActiveDownloadCount())
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
assert.Equal(t, maxBytes, rateLimiter3.bytesPerSecond)
manager.UnregisterDownload(user3)
assert.Equal(t, 0, manager.GetActiveDownloadCount())
assert.Equal(t, maxBytes, manager.bytesPerSecondPerDownload)
}
func Test_BandwidthManager_RegisterDuplicateUser_ReturnsError(t *testing.T) {
manager := NewBandwidthManager(100)
userID := uuid.New()
_, err := manager.RegisterDownload(userID)
assert.NoError(t, err)
_, err = manager.RegisterDownload(userID)
assert.Error(t, err)
assert.Contains(t, err.Error(), "download already registered")
}
func Test_RateLimiter_TokenBucketBasic(t *testing.T) {
bytesPerSec := int64(1024 * 1024)
limiter := NewRateLimiter(bytesPerSec)
assert.Equal(t, bytesPerSec, limiter.bytesPerSecond)
assert.Equal(t, bytesPerSec*2, limiter.bucketSize)
start := time.Now()
limiter.Wait(512 * 1024)
elapsed := time.Since(start)
assert.Less(t, elapsed, 100*time.Millisecond)
}
func Test_RateLimiter_UpdateRate(t *testing.T) {
limiter := NewRateLimiter(1024 * 1024)
assert.Equal(t, int64(1024*1024), limiter.bytesPerSecond)
newRate := int64(2 * 1024 * 1024)
limiter.UpdateRate(newRate)
assert.Equal(t, newRate, limiter.bytesPerSecond)
assert.Equal(t, newRate*2, limiter.bucketSize)
}
func Test_RateLimiter_ThrottlesCorrectly(t *testing.T) {
bytesPerSec := int64(1024 * 1024)
limiter := NewRateLimiter(bytesPerSec)
limiter.availableTokens = 0
start := time.Now()
limiter.Wait(bytesPerSec / 2)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 400*time.Millisecond)
assert.LessOrEqual(t, elapsed, 700*time.Millisecond)
}


@@ -1,19 +1,43 @@
package backups_download
import (
"sync"
"sync/atomic"
"databasus-backend/internal/config"
cache_utils "databasus-backend/internal/util/cache"
"databasus-backend/internal/util/logger"
)
var downloadTokenRepository = &DownloadTokenRepository{}
var downloadTokenService = &DownloadTokenService{
	downloadTokenRepository,
	logger.GetLogger(),
}
var downloadTokenBackgroundService = &DownloadTokenBackgroundService{
	downloadTokenService,
	logger.GetLogger(),
}
var downloadTracker = NewDownloadTracker(cache_utils.GetValkeyClient())
var bandwidthManager *BandwidthManager
var downloadTokenService *DownloadTokenService
var downloadTokenBackgroundService *DownloadTokenBackgroundService
func init() {
env := config.GetEnv()
throughputMBs := env.NodeNetworkThroughputMBs
if throughputMBs == 0 {
throughputMBs = 125
}
bandwidthManager = NewBandwidthManager(throughputMBs)
downloadTokenService = &DownloadTokenService{
downloadTokenRepository,
logger.GetLogger(),
downloadTracker,
bandwidthManager,
}
downloadTokenBackgroundService = &DownloadTokenBackgroundService{
downloadTokenService: downloadTokenService,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
}
func GetDownloadTokenService() *DownloadTokenService {
@@ -23,3 +47,7 @@ func GetDownloadTokenService() *DownloadTokenService {
func GetDownloadTokenBackgroundService() *DownloadTokenBackgroundService {
return downloadTokenBackgroundService
}
func GetBandwidthManager() *BandwidthManager {
return bandwidthManager
}


@@ -0,0 +1,101 @@
package backups_download
import (
"io"
"sync"
"time"
)
type RateLimiter struct {
mu sync.Mutex
bytesPerSecond int64
bucketSize int64
availableTokens float64
lastRefill time.Time
}
func NewRateLimiter(bytesPerSecond int64) *RateLimiter {
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
return &RateLimiter{
bytesPerSecond: bytesPerSecond,
bucketSize: bytesPerSecond * 2,
availableTokens: float64(bytesPerSecond * 2),
lastRefill: time.Now().UTC(),
}
}
func (rl *RateLimiter) UpdateRate(bytesPerSecond int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
if bytesPerSecond <= 0 {
bytesPerSecond = 1024 * 1024 * 100
}
rl.bytesPerSecond = bytesPerSecond
rl.bucketSize = bytesPerSecond * 2
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
}
func (rl *RateLimiter) Wait(bytes int64) {
rl.mu.Lock()
defer rl.mu.Unlock()
for {
now := time.Now().UTC()
elapsed := now.Sub(rl.lastRefill).Seconds()
tokensToAdd := elapsed * float64(rl.bytesPerSecond)
rl.availableTokens += tokensToAdd
if rl.availableTokens > float64(rl.bucketSize) {
rl.availableTokens = float64(rl.bucketSize)
}
rl.lastRefill = now
if rl.availableTokens >= float64(bytes) {
rl.availableTokens -= float64(bytes)
return
}
tokensNeeded := float64(bytes) - rl.availableTokens
waitTime := time.Duration(tokensNeeded/float64(rl.bytesPerSecond)*1000) * time.Millisecond
if waitTime < time.Millisecond {
waitTime = time.Millisecond
}
rl.mu.Unlock()
time.Sleep(waitTime)
rl.mu.Lock()
}
}
type RateLimitedReader struct {
reader io.ReadCloser
rateLimiter *RateLimiter
}
func NewRateLimitedReader(reader io.ReadCloser, limiter *RateLimiter) *RateLimitedReader {
return &RateLimitedReader{
reader: reader,
rateLimiter: limiter,
}
}
func (r *RateLimitedReader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
if n > 0 {
r.rateLimiter.Wait(int64(n))
}
return n, err
}
func (r *RateLimitedReader) Close() error {
return r.reader.Close()
}
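A token bucket with capacity twice the rate means the first two seconds' worth of bytes pass as a burst; once the bucket drains, Wait sleeps tokensNeeded/rate, so at 1 MiB/s an empty-bucket 512 KiB read blocks about 0.5 s, which is the 400-700 ms window Test_RateLimiter_ThrottlesCorrectly asserts earlier in this diff. A sketch of wrapping an arbitrary stream, written as if inside this package (io.NopCloser only satisfies the ReadCloser requirement; this is illustrative, not code from the diff):

package backups_download

import (
	"io"
	"strings"
)

// rateLimitedCopySketch wraps an in-memory 4 MiB stream with the limiter.
// At 1 MiB/s the bucket starts full at 2 MiB, so roughly the first 2 MiB
// pass as a burst and the remaining 2 MiB take about two more seconds.
func rateLimitedCopySketch() error {
	limiter := NewRateLimiter(1024 * 1024)
	payload := strings.NewReader(strings.Repeat("x", 4*1024*1024))
	reader := NewRateLimitedReader(io.NopCloser(payload), limiter)
	defer reader.Close()
	_, err := io.Copy(io.Discard, reader)
	return err
}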


@@ -9,11 +9,17 @@ import (
)
type DownloadTokenService struct {
repository *DownloadTokenRepository
logger *slog.Logger
repository *DownloadTokenRepository
logger *slog.Logger
downloadTracker *DownloadTracker
bandwidthManager *BandwidthManager
}
func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, error) {
if s.downloadTracker.IsDownloadInProgress(userID) {
return "", ErrDownloadAlreadyInProgress
}
token := GenerateSecureToken()
downloadToken := &DownloadToken{
@@ -32,22 +38,34 @@ func (s *DownloadTokenService) Generate(backupID, userID uuid.UUID) (string, err
return token, nil
}
func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken, error) {
func (s *DownloadTokenService) ValidateAndConsume(
token string,
) (*DownloadToken, *RateLimiter, error) {
dt, err := s.repository.FindByToken(token)
if err != nil {
return nil, err
return nil, nil, err
}
if dt == nil {
return nil, errors.New("invalid token")
return nil, nil, errors.New("invalid token")
}
if dt.Used {
return nil, errors.New("token already used")
return nil, nil, errors.New("token already used")
}
if time.Now().UTC().After(dt.ExpiresAt) {
return nil, errors.New("token expired")
return nil, nil, errors.New("token expired")
}
if err := s.downloadTracker.AcquireDownloadLock(dt.UserID); err != nil {
return nil, nil, err
}
rateLimiter, err := s.bandwidthManager.RegisterDownload(dt.UserID)
if err != nil {
s.downloadTracker.ReleaseDownloadLock(dt.UserID)
return nil, nil, err
}
dt.Used = true
@@ -55,8 +73,26 @@ func (s *DownloadTokenService) ValidateAndConsume(token string) (*DownloadToken,
s.logger.Error("Failed to mark token as used", "error", err)
}
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID)
return dt, nil
s.logger.Info("Token validated and consumed", "backupId", dt.BackupID, "userId", dt.UserID)
return dt, rateLimiter, nil
}
func (s *DownloadTokenService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTracker.RefreshDownloadLock(userID)
}
func (s *DownloadTokenService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTracker.ReleaseDownloadLock(userID)
s.logger.Info("Released download lock", "userId", userID)
}
func (s *DownloadTokenService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTracker.IsDownloadInProgress(userID)
}
func (s *DownloadTokenService) UnregisterDownload(userID uuid.UUID) {
s.bandwidthManager.UnregisterDownload(userID)
s.logger.Info("Unregistered from bandwidth manager", "userId", userID)
}
func (s *DownloadTokenService) CleanExpiredTokens() error {
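One detail worth noting in ValidateAndConsume above: the Valkey download lock is acquired first, and if the bandwidth-slot registration then fails, the lock is released before returning, so a rejected download never leaves a user stuck. A standalone sketch of that acquire-then-roll-back shape (all names here are illustrative):

package main

import (
	"errors"
	"fmt"
)

// acquireBoth takes two resources in order and rolls back the first
// if the second fails.
func acquireBoth(acquireLock, acquireSlot func() error, releaseLock func()) error {
	if err := acquireLock(); err != nil {
		return err
	}
	if err := acquireSlot(); err != nil {
		releaseLock() // roll back the first acquisition
		return err
	}
	return nil
}

func main() {
	err := acquireBoth(
		func() error { return nil },
		func() error { return errors.New("download already registered") },
		func() { fmt.Println("lock released") },
	)
	fmt.Println(err) // prints "lock released", then the registration error
}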


@@ -0,0 +1,66 @@
package backups_download
import (
cache_utils "databasus-backend/internal/util/cache"
"errors"
"time"
"github.com/google/uuid"
"github.com/valkey-io/valkey-go"
)
const (
downloadLockPrefix = "backup_download_lock:"
downloadLockTTL = 5 * time.Second
downloadLockValue = "1"
downloadHeartbeatDelay = 3 * time.Second
)
var (
ErrDownloadAlreadyInProgress = errors.New("download already in progress for this user")
)
type DownloadTracker struct {
cache *cache_utils.CacheUtil[string]
}
func NewDownloadTracker(client valkey.Client) *DownloadTracker {
return &DownloadTracker{
cache: cache_utils.NewCacheUtil[string](client, downloadLockPrefix),
}
}
func (t *DownloadTracker) AcquireDownloadLock(userID uuid.UUID) error {
key := userID.String()
existingLock := t.cache.Get(key)
if existingLock != nil {
return ErrDownloadAlreadyInProgress
}
value := downloadLockValue
t.cache.Set(key, &value)
return nil
}
func (t *DownloadTracker) RefreshDownloadLock(userID uuid.UUID) {
key := userID.String()
value := downloadLockValue
t.cache.Set(key, &value)
}
func (t *DownloadTracker) ReleaseDownloadLock(userID uuid.UUID) {
key := userID.String()
t.cache.Invalidate(key)
}
func (t *DownloadTracker) IsDownloadInProgress(userID uuid.UUID) bool {
key := userID.String()
existingLock := t.cache.Get(key)
return existingLock != nil
}
func GetDownloadHeartbeatInterval() time.Duration {
return downloadHeartbeatDelay
}
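Under the constants above (5 s lock TTL, 3 s heartbeat) a healthy download always refreshes its lock before expiry, while a crashed process frees its user automatically. This assumes CacheUtil.Set applies downloadLockTTL, which this snippet does not show. The lifecycle, roughly:

// t=0s   AcquireDownloadLock  -> key stored with a 5 s TTL
// t=3s   heartbeat ticks      -> RefreshDownloadLock re-sets the key and TTL
// t=6s   heartbeat ticks      -> refreshed again, and so on
// done   ReleaseDownloadLock  -> key invalidated immediately
// crash  no more heartbeats   -> the entry expires on its own within 5 s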


@@ -9,7 +9,6 @@ import (
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups/backuping"
backups_cancellation "databasus-backend/internal/features/backups/backups/cancellation"
backups_core "databasus-backend/internal/features/backups/backups/core"
backups_download "databasus-backend/internal/features/backups/backups/download"
"databasus-backend/internal/features/backups/backups/encryption"
@@ -18,9 +17,11 @@ import (
encryption_secrets "databasus-backend/internal/features/encryption/secrets"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
task_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
util_encryption "databasus-backend/internal/util/encryption"
files_utils "databasus-backend/internal/util/files"
"github.com/google/uuid"
)
@@ -43,9 +44,10 @@ type BackupService struct {
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupCancelManager *backups_cancellation.BackupCancelManager
taskCancelManager *task_cancellation.TaskCancelManager
downloadTokenService *backups_download.DownloadTokenService
backupSchedulerService *backuping.BackupsScheduler
backupCleaner *backuping.BackupCleaner
}
func (s *BackupService) AddBackupRemoveListener(listener backups_core.BackupRemoveListener) {
@@ -91,7 +93,7 @@ func (s *BackupService) MakeBackupWithAuth(
return errors.New("insufficient permissions to create backup for this database")
}
s.backupSchedulerService.StartBackup(databaseID, true)
s.backupSchedulerService.StartBackup(database, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
@@ -180,16 +182,12 @@ func (s *BackupService) DeleteBackup(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup deleted for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return s.deleteBackup(backup)
return s.backupCleaner.DeleteBackup(backup)
}
func (s *BackupService) GetBackup(backupID uuid.UUID) (*backups_core.Backup, error) {
@@ -226,16 +224,12 @@ func (s *BackupService) CancelBackup(
return errors.New("backup is not in progress")
}
if err := s.backupCancelManager.CancelBackup(backupID); err != nil {
if err := s.taskCancelManager.CancelTask(backupID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup cancelled for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -275,11 +269,7 @@ func (s *BackupService) GetBackupFile(
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
@@ -292,29 +282,6 @@ func (s *BackupService) GetBackupFile(
return reader, backup, database, nil
}
func (s *BackupService) deleteBackup(backup *backups_core.Backup) error {
for _, listener := range s.backupRemoveListeners {
if err := listener.OnBeforeBackupRemove(backup); err != nil {
return err
}
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return err
}
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
// we do not return an error here because cleanup is sometimes performed
// before removing or changing an unavailable storage, so we should
// proceed even if the file deletion fails
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
}
func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
dbBackupsInProgress, err := s.backupRepository.FindByDatabaseIdAndStatus(
databaseID,
@@ -336,7 +303,7 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
}
for _, dbBackup := range dbBackups {
err := s.deleteBackup(dbBackup)
err := s.backupCleaner.DeleteBackup(dbBackup)
if err != nil {
return err
}
@@ -358,7 +325,7 @@ func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, erro
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.FileName)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
@@ -481,7 +448,7 @@ func (s *BackupService) GenerateDownloadToken(
func (s *BackupService) ValidateDownloadToken(
token string,
) (*backups_download.DownloadToken, error) {
) (*backups_download.DownloadToken, *backups_download.RateLimiter, error) {
return s.downloadTokenService.ValidateAndConsume(token)
}
@@ -512,22 +479,34 @@ func (s *BackupService) WriteAuditLogForDownload(
database *databases.Database,
) {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backup.ID.String(),
),
fmt.Sprintf("Backup file downloaded for database: %s", database.Name),
&userID,
database.WorkspaceID,
)
}
func (s *BackupService) RefreshDownloadLock(userID uuid.UUID) {
s.downloadTokenService.RefreshDownloadLock(userID)
}
func (s *BackupService) ReleaseDownloadLock(userID uuid.UUID) {
s.downloadTokenService.ReleaseDownloadLock(userID)
}
func (s *BackupService) IsDownloadInProgress(userID uuid.UUID) bool {
return s.downloadTokenService.IsDownloadInProgress(userID)
}
func (s *BackupService) UnregisterDownload(userID uuid.UUID) {
s.downloadTokenService.UnregisterDownload(userID)
}
func (s *BackupService) generateBackupFilename(
backup *backups_core.Backup,
database *databases.Database,
) string {
timestamp := backup.CreatedAt.Format("2006-01-02_15-04-05")
safeName := sanitizeFilename(database.Name)
safeName := files_utils.SanitizeFilename(database.Name)
extension := s.getBackupExtension(database.Type)
return fmt.Sprintf("%s_backup_%s%s", safeName, timestamp, extension)
}


@@ -75,3 +75,23 @@ func WaitForBackupCompletion(
t.Logf("WaitForBackupCompletion: timeout waiting for backup to complete")
}
// CreateTestBackup creates a simple test backup record for testing purposes
func CreateTestBackup(databaseID, storageID uuid.UUID) *backups_core.Backup {
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: databaseID,
StorageID: storageID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups_core.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
return backup
}


@@ -5,6 +5,7 @@ import (
"errors"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
usecases_mariadb "databasus-backend/internal/features/backups/backups/usecases/mariadb"
usecases_mongodb "databasus-backend/internal/features/backups/backups/usecases/mongodb"
usecases_mysql "databasus-backend/internal/features/backups/backups/usecases/mysql"
@@ -12,8 +13,6 @@ import (
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/storages"
"github.com/google/uuid"
)
type CreateBackupUsecase struct {
@@ -25,7 +24,7 @@ type CreateBackupUsecase struct {
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
storage *storages.Storage,
@@ -35,7 +34,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypePostgres:
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -45,7 +44,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMysql:
return uc.CreateMysqlBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -55,7 +54,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMariadb:
return uc.CreateMariadbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,
@@ -65,7 +64,7 @@ func (uc *CreateBackupUsecase) Execute(
case databases.DatabaseTypeMongodb:
return uc.CreateMongodbBackupUsecase.Execute(
ctx,
backupID,
backup,
backupConfig,
database,
storage,


@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMariadbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMariadbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMariadbExecutable(
tools.MariadbExecutableMariadbDump,
@@ -108,13 +109,15 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
"--single-transaction",
"--routines",
"--quick",
"--skip-extended-insert",
"--verbose",
}
if mdb.HasPrivilege("TRIGGER") {
args = append(args, "--triggers")
}
if mdb.HasPrivilege("EVENT") {
if mdb.HasPrivilege("EVENT") && !mdb.IsExcludeEvents {
args = append(args, "--events")
}
@@ -134,7 +137,7 @@ func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
func (uc *CreateMariadbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mariadbBin string,
args []string,
@@ -185,7 +188,7 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -202,7 +205,13 @@ func (uc *CreateMariadbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -418,7 +427,9 @@ func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone


@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -46,7 +47,7 @@ type writeResult struct {
func (uc *CreateMongodbBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -76,7 +77,7 @@ func (uc *CreateMongodbBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMongodbExecutable(
tools.MongodbExecutableMongodump,
@@ -114,7 +115,7 @@ func (uc *CreateMongodbBackupUsecase) buildMongodumpArgs(
func (uc *CreateMongodbBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mongodumpBin string,
args []string,
@@ -163,7 +164,7 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -175,7 +176,13 @@ func (uc *CreateMongodbBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -262,6 +269,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
backupMetadata := common.BackupMetadata{
BackupID: backupID,
Encryption: backups_config.BackupEncryptionNone,
}
@@ -298,6 +306,7 @@ func (uc *CreateMongodbBackupUsecase) setupBackupEncryption(
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
backupMetadata.BackupID = backupID
backupMetadata.Encryption = backups_config.BackupEncryptionEncrypted
backupMetadata.EncryptionSalt = &saltBase64
backupMetadata.EncryptionIV = &nonceBase64


@@ -19,6 +19,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -52,7 +53,7 @@ type writeResult struct {
func (uc *CreateMysqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -82,7 +83,7 @@ func (uc *CreateMysqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetMysqlExecutable(
my.Version,
@@ -107,6 +108,7 @@ func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatab
"--routines",
"--set-gtid-purged=OFF",
"--quick",
"--skip-extended-insert",
"--verbose",
}
@@ -148,7 +150,7 @@ func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.Mysq
func (uc *CreateMysqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
mysqlBin string,
args []string,
@@ -199,7 +201,7 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -216,7 +218,13 @@ func (uc *CreateMysqlBackupUsecase) streamToStorage(
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -430,7 +438,9 @@ func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone


@@ -16,6 +16,7 @@ import (
"databasus-backend/internal/config"
common "databasus-backend/internal/features/backups/backups/common"
backups_core "databasus-backend/internal/features/backups/backups/core"
backup_encryption "databasus-backend/internal/features/backups/backups/encryption"
backups_config "databasus-backend/internal/features/backups/config"
"databasus-backend/internal/features/databases"
@@ -53,7 +54,7 @@ type writeResult struct {
func (uc *CreatePostgresqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
storage *storages.Storage,
@@ -88,7 +89,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
return uc.streamToStorage(
ctx,
backupID,
backup,
backupConfig,
tools.GetPostgresqlExecutable(
pg.Version,
@@ -107,7 +108,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
// streamToStorage streams pg_dump output directly to storage
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backup *backups_core.Backup,
backupConfig *backups_config.BackupConfig,
pgBin string,
args []string,
@@ -166,7 +167,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storageReader, storageWriter := io.Pipe()
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backup.ID,
backupConfig,
storageWriter,
)
@@ -181,7 +182,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(
ctx,
uc.fieldEncryptor,
uc.logger,
backup.FileName,
storageReader,
)
saveErrCh <- saveErr
}()
@@ -475,7 +482,9 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, common.BackupMetadata, error) {
metadata := common.BackupMetadata{}
metadata := common.BackupMetadata{
BackupID: backupID,
}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone


@@ -16,6 +16,7 @@ type BackupConfigController struct {
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backup-configs/save", c.SaveBackupConfig)
router.GET("/backup-configs/database/:id/plan", c.GetDatabasePlan)
router.GET("/backup-configs/database/:id", c.GetBackupConfigByDbID)
router.GET("/backup-configs/storage/:id/is-using", c.IsStorageUsing)
router.GET("/backup-configs/storage/:id/databases-count", c.CountDatabasesForStorage)
@@ -92,6 +93,39 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
ctx.JSON(http.StatusOK, backupConfig)
}
// GetDatabasePlan
// @Summary Get database plan by database ID
// @Description Get the plan limits for a specific database (max backup size, max total size, max storage period)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} plans.DatabasePlan
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Database not found or access denied"
// @Router /backup-configs/database/{id}/plan [get]
func (c *BackupConfigController) GetDatabasePlan(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
plan, err := c.backupConfigService.GetDatabasePlan(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "database plan not found"})
return
}
ctx.JSON(http.StatusOK, plan)
}
// IsStorageUsing
// @Summary Check if storage is being used
// @Description Check if a storage is currently being used by any backup configuration


@@ -6,6 +6,7 @@ import (
"net/http"
"strconv"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -16,11 +17,14 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
users_enums "databasus-backend/internal/features/users/enums"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/period"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -89,6 +93,11 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -109,9 +118,10 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -137,7 +147,7 @@ func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, period.PeriodWeek, response.RetentionTimePeriod)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
@@ -152,13 +162,19 @@ func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *test
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -242,6 +258,11 @@ func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -290,6 +311,11 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
@@ -300,14 +326,218 @@ func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T)
&response,
)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, plan.MaxStoragePeriod, response.RetentionTimePeriod)
assert.Equal(t, plan.MaxBackupSizeMB, response.MaxBackupSizeMB)
assert.Equal(t, plan.MaxBackupsTotalSizeMB, response.MaxBackupsTotalSizeMB)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_GetDatabasePlan_ForNewDatabase_PlanAlwaysReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var response plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.MaxBackupSizeMB)
assert.NotNil(t, response.MaxBackupsTotalSizeMB)
assert.NotEmpty(t, response.MaxStoragePeriod)
}
func Test_SaveBackupConfig_WhenPlanLimitsAreAdjusted_ValidationEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
// Get plan via API (triggers auto-creation)
var plan plans.DatabasePlan
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String()+"/plan",
"Bearer "+owner.Token,
http.StatusOK,
&plan,
)
assert.Equal(t, database.ID, plan.DatabaseID)
// Adjust plan limits directly in database to fixed restrictive values
err := storage.GetDb().Model(&plans.DatabasePlan{}).
Where("database_id = ?", database.ID).
Updates(map[string]any{
"max_backup_size_mb": 100,
"max_backups_total_size_mb": 1000,
"max_storage_period": period.PeriodMonth,
}).Error
assert.NoError(t, err)
// Test 1: Try to save backup config with exceeded backup size limit
timeOfDay := "04:00"
backupConfigExceededSize := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 200, // Exceeds limit of 100
MaxBackupsTotalSizeMB: 800,
}
respExceededSize := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededSize,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededSize.Body), "max backup size exceeds plan limit")
// Test 2: Try to save backup config with exceeded total size limit
backupConfigExceededTotal := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 50,
MaxBackupsTotalSizeMB: 2000, // Exceeds limit of 1000
}
respExceededTotal := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededTotal,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededTotal.Body), "max total backups size exceeds plan limit")
// Test 3: Try to save backup config with exceeded storage period limit
backupConfigExceededPeriod := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodYear, // Exceeds limit of Month
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80,
MaxBackupsTotalSizeMB: 800,
}
respExceededPeriod := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigExceededPeriod,
http.StatusBadRequest,
)
assert.Contains(t, string(respExceededPeriod.Body), "storage period exceeds plan limit")
// Test 4: Save backup config within all limits - should succeed
backupConfigValid := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek, // Within Month limit
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 80, // Within 100 limit
MaxBackupsTotalSizeMB: 800, // Within 1000 limit
}
var responseValid BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
backupConfigValid,
http.StatusOK,
&responseValid,
)
assert.Equal(t, database.ID, responseValid.DatabaseID)
assert.Equal(t, int64(80), responseValid.MaxBackupSizeMB)
assert.Equal(t, int64(800), responseValid.MaxBackupsTotalSizeMB)
assert.Equal(t, period.PeriodWeek, responseValid.RetentionTimePeriod)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -340,6 +570,10 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
storage := createTestStorage(workspace.ID)
defer func() {
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
var testUserToken string
if tt.isStorageOwner {
testUserToken = storageOwner.Token
@@ -372,10 +606,6 @@ func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "error")
}
// Cleanup
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -387,11 +617,17 @@ func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -426,11 +662,17 @@ func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -536,6 +778,15 @@ func Test_TransferDatabase_PermissionsEnforced(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
// Cleanup in correct order to avoid foreign key violations
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascade delete of backup_config
storages.RemoveTestStorage(targetStorage.ID)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
@@ -628,6 +879,12 @@ func Test_TransferDatabase_NonMemberInSourceWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -668,6 +925,12 @@ func Test_TransferDatabase_NonMemberInTargetWorkspace_CannotTransfer(t *testing.
router,
)
defer func() {
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
request := TransferDatabaseRequest{
TargetWorkspaceID: targetWorkspace.ID,
}
@@ -695,11 +958,19 @@ func Test_TransferDatabase_ToNewStorage_DatabaseTransferd(t *testing.T) {
sourceStorage := createTestStorage(sourceWorkspace.ID)
targetStorage := createTestStorage(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -774,11 +1045,19 @@ func Test_TransferDatabase_WithExistingStorage_DatabaseAndStorageTransferd(t *te
database := createTestDatabaseViaAPI("Test Database", sourceWorkspace.ID, owner.Token, router)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -863,11 +1142,20 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
)
storage := createTestStorage(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond) // Wait for cascading deletes
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest1 := BackupConfig{
DatabaseID: database1.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database1.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -891,9 +1179,10 @@ func Test_TransferDatabase_StorageHasOtherDBs_CannotTransfer(t *testing.T) {
)
backupConfigRequest2 := BackupConfig{
DatabaseID: database2.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
DatabaseID: database2.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -945,6 +1234,14 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
notifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var updatedDatabase databases.Database
test_utils.MakePostRequestAndUnmarshal(
@@ -959,9 +1256,10 @@ func Test_TransferDatabase_WithNotifiers_NotifiersTransferred(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1048,6 +1346,15 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1070,9 +1377,10 @@ func Test_TransferDatabase_NotifierHasOtherDBs_NotifierSkipped(t *testing.T) {
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database1.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database1.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1160,6 +1468,16 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
exclusiveNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
sharedNotifier := notifiers.CreateTestNotifier(sourceWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database1)
databases.RemoveTestDatabase(database2)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(exclusiveNotifier)
notifiers.RemoveTestNotifier(sharedNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database1.Notifiers = []notifiers.Notifier{*exclusiveNotifier, *sharedNotifier}
test_utils.MakePostRequest(
t,
@@ -1182,9 +1500,10 @@ func Test_TransferDatabase_WithMultipleNotifiers_OnlyExclusiveOnesTransferred(t
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database1.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database1.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1271,11 +1590,20 @@ func Test_TransferDatabase_WithTargetNotifiers_NotifiersAssigned(t *testing.T) {
targetStorage := createTestStorage(targetWorkspace.ID)
targetNotifier := notifiers.CreateTestNotifier(targetWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(targetNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1342,11 +1670,21 @@ func Test_TransferDatabase_TargetNotifierFromDifferentWorkspace_ReturnsBadReques
targetStorage := createTestStorage(targetWorkspace.ID)
wrongNotifier := notifiers.CreateTestNotifier(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
notifiers.RemoveTestNotifier(wrongNotifier)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1399,11 +1737,20 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
sourceStorage := createTestStorage(sourceWorkspace.ID)
wrongStorage := createTestStorage(otherWorkspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
time.Sleep(200 * time.Millisecond)
workspaces_testing.RemoveTestWorkspace(sourceWorkspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
workspaces_testing.RemoveTestWorkspace(otherWorkspace, router)
}()
timeOfDay := "04:00"
backupConfigRequest := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -1443,6 +1790,117 @@ func Test_TransferDatabase_TargetStorageFromDifferentWorkspace_ReturnsBadRequest
assert.Contains(t, string(testResp.Body), "target storage does not belong to target workspace")
}
func Test_SaveBackupConfig_WithSystemStorage_CanBeUsedByAnyDatabase(t *testing.T) {
router := createTestRouterWithStorageForTransfer()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
workspaceA := workspaces_testing.CreateTestWorkspace("Workspace A", owner1, router)
workspaceB := workspaces_testing.CreateTestWorkspace("Workspace B", owner2, router)
databaseA := createTestDatabaseViaAPI("Database A", workspaceA.ID, owner1.Token, router)
// Test 1: Regular storage from workspace B cannot be used by database in workspace A
regularStorageB := createTestStorage(workspaceB.ID)
timeOfDay := "04:00"
backupConfigWithRegularStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &regularStorageB.ID,
Storage: regularStorageB,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
respRegular := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithRegularStorage,
http.StatusBadRequest,
)
assert.Contains(t, string(respRegular.Body), "storage does not belong to the same workspace")
// Test 2: System storage from workspace B CAN be used by database in workspace A
systemStorageB := &storages.Storage{
WorkspaceID: workspaceB.ID,
Type: storages.StorageTypeLocal,
Name: "Test System Storage " + uuid.New().String(),
IsSystem: true,
LocalStorage: &local_storage.LocalStorage{},
}
var savedSystemStorage storages.Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*systemStorageB,
http.StatusOK,
&savedSystemStorage,
)
assert.True(t, savedSystemStorage.IsSystem)
backupConfigWithSystemStorage := BackupConfig{
DatabaseID: databaseA.ID,
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
StorageID: &savedSystemStorage.ID,
Storage: &savedSystemStorage,
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var savedConfig BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner1.Token,
backupConfigWithSystemStorage,
http.StatusOK,
&savedConfig,
)
assert.Equal(t, databaseA.ID, savedConfig.DatabaseID)
assert.NotNil(t, savedConfig.StorageID)
assert.Equal(t, savedSystemStorage.ID, *savedConfig.StorageID)
assert.True(t, savedConfig.IsBackupsEnabled)
// Cleanup: database first (cascades to backup_config), then storages, then workspaces
databases.RemoveTestDatabase(databaseA)
storages.RemoveTestStorage(regularStorageB.ID)
storages.RemoveTestStorage(savedSystemStorage.ID)
workspaces_testing.RemoveTestWorkspace(workspaceA, router)
workspaces_testing.RemoveTestWorkspace(workspaceB, router)
}
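The rule this test exercises appears as a one-line guard in the service diff further down: cross-workspace storage is rejected unless the storage is flagged IsSystem. A minimal standalone sketch of that rule (the wrapper function here is illustrative, not project code):

package main

import (
	"errors"
	"fmt"
)

// storageUsableBy mirrors the workspace guard from the service layer: storage
// in another workspace is only acceptable when it is a system storage.
func storageUsableBy(storageWorkspaceID, databaseWorkspaceID string, isSystem bool) error {
	if storageWorkspaceID != databaseWorkspaceID && !isSystem {
		return errors.New("storage does not belong to the same workspace as the database")
	}
	return nil
}

func main() {
	fmt.Println(storageUsableBy("workspace-b", "workspace-a", false)) // rejected
	fmt.Println(storageUsableBy("workspace-b", "workspace-a", true))  // <nil>: shared system storage
}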
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
@@ -1462,7 +1920,7 @@ func createTestDatabaseViaAPI(
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",

View File

@@ -1,10 +1,15 @@
package backups_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/logger"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -14,6 +19,7 @@ var backupConfigService = &BackupConfigService{
storages.GetStorageService(),
notifiers.GetNotifierService(),
workspaces_services.GetWorkspaceService(),
plans.GetDatabasePlanService(),
nil,
}
var backupConfigController = &BackupConfigController{
@@ -28,6 +34,21 @@ func GetBackupConfigService() *BackupConfigService {
return backupConfigService
}
+var (
+setupOnce sync.Once
+isSetup atomic.Bool
+)
func SetupDependencies() {
-storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
+wasAlreadySetup := isSetup.Load()
+setupOnce.Do(func() {
+storages.GetStorageService().SetStorageDatabaseCounter(backupConfigService)
+isSetup.Store(true)
+})
+if wasAlreadySetup {
+logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
+}
}
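The guarded setup above follows a common idempotent-initialization pattern. A minimal generic sketch with illustrative names (only sync.Once provides the actual run-once guarantee; the atomic flag merely lets later calls detect that they were redundant and log it):

package main

import (
	"log"
	"sync"
	"sync/atomic"
)

var (
	wireOnce sync.Once
	wired    atomic.Bool
)

// setupOnce runs wire at most once; subsequent calls only log a warning.
func setupOnce(wire func()) {
	wasAlreadyWired := wired.Load()
	wireOnce.Do(func() {
		wire()
		wired.Store(true)
	})
	if wasAlreadyWired {
		log.Println("setup called multiple times, ignoring subsequent call")
	}
}

func main() {
	setupOnce(func() { log.Println("dependencies wired") })
	setupOnce(func() { log.Println("never runs") }) // logs the warning instead
}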

View File

@@ -13,3 +13,11 @@ const (
BackupEncryptionNone BackupEncryption = "NONE"
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)
type RetentionPolicyType string
const (
RetentionPolicyTypeTimePeriod RetentionPolicyType = "TIME_PERIOD"
RetentionPolicyTypeCount RetentionPolicyType = "COUNT"
RetentionPolicyTypeGFS RetentionPolicyType = "GFS"
)
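For context, GFS (grandfather-father-son) retention keeps a capped number of backups per time tier — e.g. the last 24 hourly, 7 daily, 4 weekly. A hypothetical sketch of the bucketing idea (not the service's actual pruning code), assuming a newest-first list of backup timestamps:

package main

import (
	"fmt"
	"time"
)

// keepNewestPerBucket keeps at most limit backups, at most one per bucket key
// (e.g. "2006-01-02" buckets yield a daily tier); input is newest-first.
func keepNewestPerBucket(backups []time.Time, limit int, bucket func(time.Time) string) []time.Time {
	seen := map[string]bool{}
	var kept []time.Time
	for _, b := range backups {
		k := bucket(b)
		if !seen[k] && len(kept) < limit {
			seen[k] = true
			kept = append(kept, b) // newest backup wins within each bucket
		}
	}
	return kept
}

func main() {
	t0 := time.Date(2026, 2, 22, 12, 0, 0, 0, time.UTC)
	backups := []time.Time{t0, t0.Add(-2 * time.Hour), t0.Add(-26 * time.Hour)}
	daily := keepNewestPerBucket(backups, 7, func(t time.Time) string {
		return t.UTC().Format("2006-01-02")
	})
	fmt.Println(len(daily)) // 2: one backup per calendar day, capped by RetentionGfsDays
}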

View File

@@ -1,7 +1,9 @@
package backups_config
import (
"databasus-backend/internal/config"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/util/period"
"errors"
@@ -16,7 +18,15 @@ type BackupConfig struct {
IsBackupsEnabled bool `json:"isBackupsEnabled" gorm:"column:is_backups_enabled;type:boolean;not null"`
StorePeriod period.Period `json:"storePeriod" gorm:"column:store_period;type:text;not null"`
RetentionPolicyType RetentionPolicyType `json:"retentionPolicyType" gorm:"column:retention_policy_type;type:text;not null;default:'TIME_PERIOD'"`
RetentionTimePeriod period.TimePeriod `json:"retentionTimePeriod" gorm:"column:retention_time_period;type:text;not null;default:''"`
RetentionCount int `json:"retentionCount" gorm:"column:retention_count;type:int;not null;default:0"`
RetentionGfsHours int `json:"retentionGfsHours" gorm:"column:retention_gfs_hours;type:int;not null;default:0"`
RetentionGfsDays int `json:"retentionGfsDays" gorm:"column:retention_gfs_days;type:int;not null;default:0"`
RetentionGfsWeeks int `json:"retentionGfsWeeks" gorm:"column:retention_gfs_weeks;type:int;not null;default:0"`
RetentionGfsMonths int `json:"retentionGfsMonths" gorm:"column:retention_gfs_months;type:int;not null;default:0"`
RetentionGfsYears int `json:"retentionGfsYears" gorm:"column:retention_gfs_years;type:int;not null;default:0"`
BackupIntervalID uuid.UUID `json:"backupIntervalId" gorm:"column:backup_interval_id;type:uuid;not null"`
BackupInterval *intervals.Interval `json:"backupInterval,omitempty" gorm:"foreignKey:BackupIntervalID"`
@@ -31,6 +41,11 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
// MaxBackupSizeMB limits individual backup size. 0 = unlimited.
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
// MaxBackupsTotalSizeMB limits total size of all backups. 0 = unlimited.
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
}
func (h *BackupConfig) TableName() string {
@@ -70,14 +85,13 @@ func (b *BackupConfig) AfterFind(tx *gorm.DB) error {
return nil
}
-func (b *BackupConfig) Validate() error {
-// Backup interval is required either as ID or as object
+func (b *BackupConfig) Validate(plan *plans.DatabasePlan) error {
if b.BackupIntervalID == uuid.Nil && b.BackupInterval == nil {
return errors.New("backup interval is required")
}
-if b.StorePeriod == "" {
-return errors.New("store period is required")
+if err := b.validateRetentionPolicy(plan); err != nil {
+return err
}
if b.IsRetryIfFailed && b.MaxFailedTriesCount <= 0 {
@@ -89,20 +103,87 @@ func (b *BackupConfig) Validate() error {
return errors.New("encryption must be NONE or ENCRYPTED")
}
if config.GetEnv().IsCloud {
if b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption is mandatory for cloud storage")
}
}
if b.MaxBackupSizeMB < 0 {
return errors.New("max backup size must be non-negative")
}
if b.MaxBackupsTotalSizeMB < 0 {
return errors.New("max backups total size must be non-negative")
}
if plan.MaxBackupSizeMB > 0 {
if b.MaxBackupSizeMB == 0 || b.MaxBackupSizeMB > plan.MaxBackupSizeMB {
return errors.New("max backup size exceeds plan limit")
}
}
if plan.MaxBackupsTotalSizeMB > 0 {
if b.MaxBackupsTotalSizeMB == 0 ||
b.MaxBackupsTotalSizeMB > plan.MaxBackupsTotalSizeMB {
return errors.New("max total backups size exceeds plan limit")
}
}
return nil
}
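Note the inversion in the plan checks above: 0 means "unlimited" on both sides, so an unlimited config (MaxBackupSizeMB == 0) necessarily exceeds any finite plan cap and must be rejected. The rule in isolation (a standalone sketch with the same semantics as Validate):

package main

import "fmt"

// exceedsCap reports whether a configured limit (0 = unlimited) violates a
// plan cap (0 = no cap).
func exceedsCap(configured, planCap int64) bool {
	if planCap <= 0 {
		return false // plan imposes no cap
	}
	return configured == 0 || configured > planCap
}

func main() {
	fmt.Println(exceedsCap(100, 500)) // false: within the cap
	fmt.Println(exceedsCap(0, 500))   // true: unlimited exceeds a finite cap
	fmt.Println(exceedsCap(0, 0))     // false: both unlimited
}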
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
return &BackupConfig{
-DatabaseID: newDatabaseID,
-IsBackupsEnabled: b.IsBackupsEnabled,
-StorePeriod: b.StorePeriod,
-BackupIntervalID: uuid.Nil,
-BackupInterval: b.BackupInterval.Copy(),
-StorageID: b.StorageID,
-SendNotificationsOn: b.SendNotificationsOn,
-IsRetryIfFailed: b.IsRetryIfFailed,
-MaxFailedTriesCount: b.MaxFailedTriesCount,
-Encryption: b.Encryption,
+DatabaseID: newDatabaseID,
+IsBackupsEnabled: b.IsBackupsEnabled,
+RetentionPolicyType: b.RetentionPolicyType,
+RetentionTimePeriod: b.RetentionTimePeriod,
+RetentionCount: b.RetentionCount,
+RetentionGfsHours: b.RetentionGfsHours,
+RetentionGfsDays: b.RetentionGfsDays,
+RetentionGfsWeeks: b.RetentionGfsWeeks,
+RetentionGfsMonths: b.RetentionGfsMonths,
+RetentionGfsYears: b.RetentionGfsYears,
+BackupIntervalID: uuid.Nil,
+BackupInterval: b.BackupInterval.Copy(),
+StorageID: b.StorageID,
+SendNotificationsOn: b.SendNotificationsOn,
+IsRetryIfFailed: b.IsRetryIfFailed,
+MaxFailedTriesCount: b.MaxFailedTriesCount,
+Encryption: b.Encryption,
+MaxBackupSizeMB: b.MaxBackupSizeMB,
+MaxBackupsTotalSizeMB: b.MaxBackupsTotalSizeMB,
}
}
func (b *BackupConfig) validateRetentionPolicy(plan *plans.DatabasePlan) error {
switch b.RetentionPolicyType {
case RetentionPolicyTypeTimePeriod, "":
if b.RetentionTimePeriod == "" {
return errors.New("retention time period is required")
}
if plan.MaxStoragePeriod != period.PeriodForever {
if b.RetentionTimePeriod.CompareTo(plan.MaxStoragePeriod) > 0 {
return errors.New("storage period exceeds plan limit")
}
}
case RetentionPolicyTypeCount:
if b.RetentionCount <= 0 {
return errors.New("retention count must be greater than 0")
}
case RetentionPolicyTypeGFS:
if b.RetentionGfsHours <= 0 && b.RetentionGfsDays <= 0 && b.RetentionGfsWeeks <= 0 &&
b.RetentionGfsMonths <= 0 && b.RetentionGfsYears <= 0 {
return errors.New("at least one GFS retention field must be greater than 0")
}
default:
return errors.New("invalid retention policy type")
}
return nil
}

View File

@@ -0,0 +1,477 @@
package backups_config
import (
"testing"
"databasus-backend/internal/features/intervals"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_Validate_WhenRetentionTimePeriodIsWeekAndPlanAllowsMonth_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodWeek
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsYearAndPlanAllowsMonth_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsForever_ValidationPasses(
t *testing.T,
) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodForever
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenRetentionTimePeriodIsForeverAndPlanAllowsYear_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodYear
err := config.Validate(plan)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenRetentionTimePeriodEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodMonth
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize100MBAndPlanAllows500MB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 100
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSize500MBAndPlanAllows100MB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 100
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenBackupSizeIsUnlimitedAndPlanHas500MBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size exceeds plan limit")
}
func Test_Validate_WhenBackupSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = 500
plan := createUnlimitedPlan()
plan.MaxBackupSizeMB = 500
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize1GBAndPlanAllows5GB_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 1000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSize5GBAndPlanAllows1GB_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanAllowsUnlimited_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 0
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenTotalSizeIsUnlimitedAndPlanHas1GBLimit_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.EqualError(t, err, "max total backups size exceeds plan limit")
}
func Test_Validate_WhenTotalSizeEqualsExactPlanLimit_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxBackupsTotalSizeMB = 5000
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenAllLimitsAreUnlimitedInPlan_AnyConfigurationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodForever
config.MaxBackupSizeMB = 0
config.MaxBackupsTotalSizeMB = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenMultipleLimitsExceeded_ValidationFailsWithFirstError(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = period.PeriodYear
config.MaxBackupSizeMB = 500
config.MaxBackupsTotalSizeMB = 5000
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = period.PeriodMonth
plan.MaxBackupSizeMB = 100
plan.MaxBackupsTotalSizeMB = 1000
err := config.Validate(plan)
assert.Error(t, err)
assert.EqualError(t, err, "storage period exceeds plan limit")
}
func Test_Validate_WhenConfigHasInvalidIntervalButPlanIsValid_ValidationFailsOnInterval(
t *testing.T,
) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenIntervalIsMissing_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.BackupIntervalID = uuid.Nil
config.BackupInterval = nil
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "backup interval is required")
}
func Test_Validate_WhenRetryEnabledButMaxTriesIsZero_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.IsRetryIfFailed = true
config.MaxFailedTriesCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max failed tries count must be greater than 0")
}
func Test_Validate_WhenEncryptionIsInvalid_ValidationFailsRegardlessOfPlan(t *testing.T) {
config := createValidBackupConfig()
config.Encryption = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "encryption must be NONE or ENCRYPTED")
}
func Test_Validate_WhenRetentionTimePeriodIsEmpty_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = ""
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "retention time period is required")
}
func Test_Validate_WhenMaxBackupSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupSizeMB = -100
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backup size must be non-negative")
}
func Test_Validate_WhenMaxTotalSizeIsNegative_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.MaxBackupsTotalSizeMB = -1000
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "max backups total size must be non-negative")
}
func Test_Validate_WhenPlanLimitsAreAtBoundary_ValidationWorks(t *testing.T) {
tests := []struct {
name string
configPeriod period.TimePeriod
planPeriod period.TimePeriod
configSize int64
planSize int64
configTotal int64
planTotal int64
shouldSucceed bool
}{
{
name: "all values just under limit",
configPeriod: period.PeriodWeek,
planPeriod: period.PeriodMonth,
configSize: 99,
planSize: 100,
configTotal: 999,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "all values equal to limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: true,
},
{
name: "period just over limit",
configPeriod: period.Period3Month,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 101,
planSize: 100,
configTotal: 1000,
planTotal: 1000,
shouldSucceed: false,
},
{
name: "total size just over limit",
configPeriod: period.PeriodMonth,
planPeriod: period.PeriodMonth,
configSize: 100,
planSize: 100,
configTotal: 1001,
planTotal: 1000,
shouldSucceed: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := createValidBackupConfig()
config.RetentionTimePeriod = tt.configPeriod
config.MaxBackupSizeMB = tt.configSize
config.MaxBackupsTotalSizeMB = tt.configTotal
plan := createUnlimitedPlan()
plan.MaxStoragePeriod = tt.planPeriod
plan.MaxBackupSizeMB = tt.planSize
plan.MaxBackupsTotalSizeMB = tt.planTotal
err := config.Validate(plan)
if tt.shouldSucceed {
assert.NoError(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_Validate_WhenPolicyTypeIsCount_RequiresPositiveCount(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "retention count must be greater than 0")
}
func Test_Validate_WhenPolicyTypeIsCount_WithPositiveCount_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeCount
config.RetentionCount = 10
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_RequiresAtLeastOneField(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 0
config.RetentionGfsWeeks = 0
config.RetentionGfsMonths = 0
config.RetentionGfsYears = 0
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "at least one GFS retention field must be greater than 0")
}
func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyHours_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_WithOnlyDays_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsDays = 7
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsGFS_WithAllFields_ValidationPasses(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = RetentionPolicyTypeGFS
config.RetentionGfsHours = 24
config.RetentionGfsDays = 7
config.RetentionGfsWeeks = 4
config.RetentionGfsMonths = 12
config.RetentionGfsYears = 3
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.NoError(t, err)
}
func Test_Validate_WhenPolicyTypeIsInvalid_ValidationFails(t *testing.T) {
config := createValidBackupConfig()
config.RetentionPolicyType = "INVALID"
plan := createUnlimitedPlan()
err := config.Validate(plan)
assert.EqualError(t, err, "invalid retention policy type")
}
func createValidBackupConfig() *BackupConfig {
intervalID := uuid.New()
return &BackupConfig{
DatabaseID: uuid.New(),
IsBackupsEnabled: true,
RetentionPolicyType: RetentionPolicyTypeTimePeriod,
RetentionTimePeriod: period.PeriodMonth,
BackupIntervalID: intervalID,
BackupInterval: &intervals.Interval{ID: intervalID},
SendNotificationsOn: []BackupNotificationType{},
IsRetryIfFailed: false,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
MaxBackupSizeMB: 100,
MaxBackupsTotalSizeMB: 1000,
}
}
func createUnlimitedPlan() *plans.DatabasePlan {
return &plans.DatabasePlan{
DatabaseID: uuid.New(),
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}

View File

@@ -26,6 +26,12 @@ func Test_AttachNotifierFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -55,6 +61,13 @@ func Test_AttachNotifierFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
notifier := notifiers.CreateTestNotifier(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
testResp := test_utils.MakePostRequest(
@@ -77,6 +90,12 @@ func Test_DeleteNotifierWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database
@@ -114,6 +133,13 @@ func Test_TransferNotifierWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
database.Notifiers = []notifiers.Notifier{*notifier}
var response databases.Database

View File

@@ -6,10 +6,10 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/intervals"
"databasus-backend/internal/features/notifiers"
plans "databasus-backend/internal/features/plan"
"databasus-backend/internal/features/storages"
users_models "databasus-backend/internal/features/users/models"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
@@ -20,6 +20,7 @@ type BackupConfigService struct {
storageService *storages.StorageService
notifierService *notifiers.NotifierService
workspaceService *workspaces_services.WorkspaceService
databasePlanService *plans.DatabasePlanService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -45,7 +46,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
user *users_models.User,
backupConfig *BackupConfig,
) (*BackupConfig, error) {
-if err := backupConfig.Validate(); err != nil {
+plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
+if err != nil {
+return nil, err
+}
+if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -71,7 +77,7 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
if err != nil {
return nil, err
}
-if storage.WorkspaceID != *database.WorkspaceID {
+if storage.WorkspaceID != *database.WorkspaceID && !storage.IsSystem {
return nil, errors.New("storage does not belong to the same workspace as the database")
}
}
@@ -82,7 +88,12 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
func (s *BackupConfigService) SaveBackupConfig(
backupConfig *BackupConfig,
) (*BackupConfig, error) {
-if err := backupConfig.Validate(); err != nil {
+plan, err := s.databasePlanService.GetDatabasePlan(backupConfig.DatabaseID)
+if err != nil {
+return nil, err
+}
+if err := backupConfig.Validate(plan); err != nil {
return nil, err
}
@@ -120,6 +131,18 @@ func (s *BackupConfigService) GetBackupConfigByDbIdWithAuth(
return s.GetBackupConfigByDbId(databaseID)
}
func (s *BackupConfigService) GetDatabasePlan(
user *users_models.User,
databaseID uuid.UUID,
) (*plans.DatabasePlan, error) {
_, err := s.databaseService.GetDatabase(user, databaseID)
if err != nil {
return nil, err
}
return s.databasePlanService.GetDatabasePlan(databaseID)
}
func (s *BackupConfigService) GetBackupConfigByDbId(
databaseID uuid.UUID,
) (*BackupConfig, error) {
@@ -194,12 +217,20 @@ func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) e
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
+plan, err := s.databasePlanService.GetDatabasePlan(databaseID)
+if err != nil {
+return err
+}
timeOfDay := "04:00"
-_, err := s.backupConfigRepository.Save(&BackupConfig{
-DatabaseID: databaseID,
-IsBackupsEnabled: false,
-StorePeriod: period.PeriodWeek,
+_, err = s.backupConfigRepository.Save(&BackupConfig{
+DatabaseID: databaseID,
+IsBackupsEnabled: false,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: plan.MaxStoragePeriod,
+MaxBackupSizeMB: plan.MaxBackupSizeMB,
+MaxBackupsTotalSizeMB: plan.MaxBackupsTotalSizeMB,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -27,11 +27,18 @@ func Test_AttachStorageFromSameWorkspace_SuccessfullyAttached(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -72,11 +79,19 @@ func Test_AttachStorageFromDifferentWorkspace_ReturnsForbidden(t *testing.T) {
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
storage := createTestStorage(workspace2.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -110,11 +125,18 @@ func Test_DeleteStorageWithAttachedDatabases_CannotDelete(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
@@ -163,11 +185,19 @@ func Test_TransferStorageWithAttachedDatabase_CannotTransfer(t *testing.T) {
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
defer func() {
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
workspaces_testing.RemoveTestWorkspace(targetWorkspace, router)
}()
timeOfDay := "04:00"
request := BackupConfig{
-DatabaseID: database.ID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodWeek,
+DatabaseID: database.ID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -15,9 +15,10 @@ func EnableBackupsForTestDatabase(
timeOfDay := "16:00"
backupConfig := &BackupConfig{
-DatabaseID: databaseID,
-IsBackupsEnabled: true,
-StorePeriod: period.PeriodDay,
+DatabaseID: databaseID,
+IsBackupsEnabled: true,
+RetentionPolicyType: RetentionPolicyTypeTimePeriod,
+RetentionTimePeriod: period.PeriodDay,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,

View File

@@ -25,80 +25,6 @@ import (
"databasus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
CpuCount: 1,
}
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
@@ -142,6 +68,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var testUserToken string
if tt.isGlobalAdmin {
@@ -180,6 +107,7 @@ func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.Equal(t, "Test Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
} else {
@@ -193,6 +121,7 @@ func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -258,8 +187,10 @@ func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -305,8 +236,10 @@ func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testin
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
database.Name = "Hacked Name"
@@ -366,6 +299,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
@@ -396,6 +330,7 @@ func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
)
if !tt.expectSuccess {
defer RemoveTestDatabase(database)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
@@ -439,8 +374,10 @@ func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
@@ -517,9 +454,12 @@ func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
var testUser string
if tt.isGlobalAdmin {
@@ -561,10 +501,14 @@ func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
db1 := createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db1)
db2 := createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db2)
db3 := createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(db3)
var response []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -583,14 +527,19 @@ func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
defer workspaces_testing.RemoveTestWorkspace(workspace1, router)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
defer workspaces_testing.RemoveTestWorkspace(workspace2, router)
createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
workspace1Db1 := createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db1)
workspace1Db2 := createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
defer RemoveTestDatabase(workspace1Db2)
createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
workspace2Db1 := createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
defer RemoveTestDatabase(workspace2Db1)
var workspace1Dbs []Database
test_utils.MakeGetRequestAndUnmarshal(
@@ -667,8 +616,10 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUserToken string
if tt.isGlobalAdmin {
@@ -700,6 +651,7 @@ func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
)
if tt.expectSuccess {
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Contains(t, response.Name, "(Copy)")
} else {
@@ -713,8 +665,10 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var response Database
test_utils.MakePostRequestAndUnmarshal(
@@ -727,139 +681,14 @@ func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
&response,
)
defer RemoveTestDatabase(&response)
assert.NotEqual(t, database.ID, response.ID)
assert.Equal(t, "Test Database (Copy)", response.Name)
assert.Equal(t, workspace.ID, *response.WorkspaceID)
assert.Equal(t, database.Type, response.Type)
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func Test_CreateDatabase_PasswordIsEncryptedInDB(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -1141,3 +970,207 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
})
}
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
defer RemoveTestDatabase(database)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func getTestPostgresConfig() *postgresql.PostgresqlDatabase {
env := config.GetEnv()
port, err := strconv.Atoi(env.TestPostgres16Port)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_POSTGRES_16_PORT: %v", err))
}
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
CpuCount: 1,
}
}
func getTestMariadbConfig() *mariadb.MariadbDatabase {
env := config.GetEnv()
portStr := env.TestMariadb1011Port
if portStr == "" {
portStr = "33111"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MARIADB_1011_PORT: %v", err))
}
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
Database: &testDbName,
}
}
func getTestMongodbConfig() *mongodb.MongodbDatabase {
env := config.GetEnv()
portStr := env.TestMongodb70Port
if portStr == "" {
portStr = "27070"
}
port, err := strconv.Atoi(portStr)
if err != nil {
panic(fmt.Sprintf("Failed to parse TEST_MONGODB_70_PORT: %v", err))
}
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: config.GetEnv().TestLocalhost,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}

View File

@@ -25,13 +25,14 @@ type MariadbDatabase struct {
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
-Host string `json:"host" gorm:"type:text;not null"`
-Port int `json:"port" gorm:"type:int;not null"`
-Username string `json:"username" gorm:"type:text;not null"`
-Password string `json:"password" gorm:"type:text;not null"`
-Database *string `json:"database" gorm:"type:text"`
-IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
-Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
+Host string `json:"host" gorm:"type:text;not null"`
+Port int `json:"port" gorm:"type:int;not null"`
+Username string `json:"username" gorm:"type:text;not null"`
+Password string `json:"password" gorm:"type:text;not null"`
+Database *string `json:"database" gorm:"type:text"`
+IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
+IsExcludeEvents bool `json:"isExcludeEvents" gorm:"type:boolean;default:false"`
+Privileges string `json:"privileges" gorm:"column:privileges;type:text;not null;default:''"`
}
func (m *MariadbDatabase) TableName() string {
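The new IsExcludeEvents flag is presumably consumed by the dump command builder, which this diff does not show. A hypothetical sketch of how such a flag could map onto dump arguments (the flag names below are real mariadb-dump/mysqldump options; the function itself is illustrative):

package main

import "fmt"

// dumpArgs sketches how IsExcludeEvents might translate into dump flags.
func dumpArgs(excludeEvents bool) []string {
	args := []string{"--single-transaction", "--routines", "--triggers"}
	if !excludeEvents {
		args = append(args, "--events") // dump stored events unless excluded
	}
	return args
}

func main() {
	fmt.Println(dumpArgs(true))  // events omitted
	fmt.Println(dumpArgs(false)) // events included
}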
@@ -124,6 +125,7 @@ func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
m.Username = incoming.Username
m.Database = incoming.Database
m.IsHttps = incoming.IsHttps
m.IsExcludeEvents = incoming.IsExcludeEvents
m.Privileges = incoming.Privileges
if incoming.Password != "" {
@@ -515,9 +517,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
+// Escape underscores to match MariaDB's grant output format
+// MariaDB escapes _ as \_ in SHOW GRANTS output
+// Pattern matches either literal _ or escaped \_
+escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
-regexp.QuoteMeta(database),
+escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)
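A standalone illustration of why the escaping matters (same pattern construction as above): SHOW GRANTS prints database names with _ escaped as \_, so a pattern built from the raw name alone would miss the grant line.

package main

import (
	"fmt"
	"regexp"
	"strings"
)

func main() {
	database := "test_db"
	// As above: match both a literal _ and MariaDB's escaped \_ form.
	escaped := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
	pattern := regexp.MustCompile(
		fmt.Sprintf(`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`, escaped),
	)

	asPrinted := "GRANT ALL PRIVILEGES ON `test\\_db`.* TO 'user'@'%'" // as SHOW GRANTS prints it
	fmt.Println(pattern.MatchString(asPrinted))                     // true
	fmt.Println(pattern.MatchString("GRANT SELECT ON `test_db`.*")) // true: unescaped form also matches
}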

View File

@@ -694,6 +694,115 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MariadbVersion
port string
}{
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMariadbContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer dropUserSafe(underscoreDB, allPrivUsername)
mariadbModel := &MariadbDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mariadbModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mariadbModel.Privileges)
assert.Contains(t, mariadbModel.Privileges, "SELECT")
assert.Contains(t, mariadbModel.Privileges, "SHOW VIEW")
})
}
}
type MariadbContainer struct {
Host string
Port int
@@ -714,7 +823,7 @@ func connectToMariadbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -25,14 +25,16 @@ type MongodbDatabase struct {
Version tools.MongodbVersion `json:"version" gorm:"type:text;not null"`
-Host string `json:"host" gorm:"type:text;not null"`
-Port int `json:"port" gorm:"type:int;not null"`
-Username string `json:"username" gorm:"type:text;not null"`
-Password string `json:"password" gorm:"type:text;not null"`
-Database string `json:"database" gorm:"type:text;not null"`
-AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
-IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
-CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
+Host string `json:"host" gorm:"type:text;not null"`
+Port *int `json:"port" gorm:"type:int"`
+Username string `json:"username" gorm:"type:text;not null"`
+Password string `json:"password" gorm:"type:text;not null"`
+Database string `json:"database" gorm:"type:text;not null"`
+AuthDatabase string `json:"authDatabase" gorm:"type:text;not null;default:'admin'"`
+IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
+IsSrv bool `json:"isSrv" gorm:"column:is_srv;type:boolean;not null;default:false"`
+IsDirectConnection bool `json:"isDirectConnection" gorm:"column:is_direct_connection;type:boolean;not null;default:false"`
+CpuCount int `json:"cpuCount" gorm:"column:cpu_count;type:int;not null;default:1"`
}
func (m *MongodbDatabase) TableName() string {
@@ -43,9 +45,13 @@ func (m *MongodbDatabase) Validate() error {
if m.Host == "" {
return errors.New("host is required")
}
-if m.Port == 0 {
-return errors.New("port is required")
+if !m.IsSrv {
+if m.Port == nil || *m.Port == 0 {
+return errors.New("port is required for standard connections")
+}
}
if m.Username == "" {
return errors.New("username is required")
}
@@ -58,6 +64,7 @@ func (m *MongodbDatabase) Validate() error {
if m.CpuCount <= 0 {
return errors.New("cpu count must be greater than 0")
}
return nil
}
@@ -125,6 +132,8 @@ func (m *MongodbDatabase) Update(incoming *MongodbDatabase) {
m.Database = incoming.Database
m.AuthDatabase = incoming.AuthDatabase
m.IsHttps = incoming.IsHttps
m.IsSrv = incoming.IsSrv
m.IsDirectConnection = incoming.IsDirectConnection
m.CpuCount = incoming.CpuCount
if incoming.Password != "" {
@@ -450,9 +459,29 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
authDB = "admin"
}
tlsParams := ""
extraParams := ""
if m.IsHttps {
tlsParams = "&tls=true&tlsInsecure=true"
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/%s?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
m.Database,
authDB,
extraParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
@@ -460,10 +489,10 @@ func (m *MongodbDatabase) buildConnectionURI(password string) string {
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
-m.Port,
+port,
m.Database,
authDB,
-tlsParams,
+extraParams,
)
}
@@ -474,9 +503,28 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
authDB = "admin"
}
tlsParams := ""
extraParams := ""
if m.IsHttps {
tlsParams = "&tls=true&tlsInsecure=true"
extraParams += "&tls=true&tlsInsecure=true"
}
if m.IsDirectConnection {
extraParams += "&directConnection=true"
}
if m.IsSrv {
return fmt.Sprintf(
"mongodb+srv://%s:%s@%s/?authSource=%s&connectTimeoutMS=15000%s",
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
authDB,
extraParams,
)
}
port := 27017
if m.Port != nil {
port = *m.Port
}
return fmt.Sprintf(
@@ -484,9 +532,9 @@ func (m *MongodbDatabase) BuildMongodumpURI(password string) string {
url.QueryEscape(m.Username),
url.QueryEscape(password),
m.Host,
-m.Port,
+port,
authDB,
-tlsParams,
+extraParams,
)
}
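
As a rough sketch of what the standard branch now emits (host, credentials, and database below are made up; only the shape mirrors the format string above), including the 27017 fallback for a nil port:

package main

import (
	"fmt"
	"net/url"
)

func main() {
	// Port is optional on the model; nil falls back to MongoDB's default 27017.
	var configured *int
	port := 27017
	if configured != nil {
		port = *configured
	}
	uri := fmt.Sprintf(
		"mongodb://%s:%s@%s:%d/%s?authSource=%s&connectTimeoutMS=15000%s",
		url.QueryEscape("user"),
		url.QueryEscape("p@ss"), // escaping keeps '@' and ':' in credentials from breaking the URI
		"mongo.example.local",
		port,
		"mydb",
		"admin",
		"&directConnection=true", // pins the driver to this host, skipping topology discovery
	)
	fmt.Println(uri)
	// mongodb://user:p%40ss@mongo.example.local:27017/mydb?authSource=admin&connectTimeoutMS=15000&directConnection=true
}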

View File

@@ -9,6 +9,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -63,15 +64,17 @@ func Test_TestConnection_InsufficientPermissions_ReturnsError(t *testing.T) {
defer dropUserSafe(container.Client, limitedUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
-Port: container.Port,
+Port: &port,
Username: limitedUsername,
Password: limitedPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -132,15 +135,17 @@ func Test_TestConnection_SufficientPermissions_Success(t *testing.T) {
defer dropUserSafe(container.Client, backupUsername, container.AuthDatabase)
port := container.Port
mongodbModel := &MongodbDatabase{
Version: tc.version,
Host: container.Host,
-Port: container.Port,
+Port: &port,
Username: backupUsername,
Password: backupPassword,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
@@ -397,7 +402,7 @@ func connectToMongodbContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"
authDatabase := "admin"
@@ -406,11 +411,18 @@ func connectToMongodbContainer(
assert.NoError(t, err)
uri := fmt.Sprintf(
"mongodb://%s:%s@%s:%d/%s?authSource=%s",
username, password, host, portInt, dbName, authDatabase,
"mongodb://%s:%s@%s:%d/%s?authSource=%s&serverSelectionTimeoutMS=5000&connectTimeoutMS=5000",
username,
password,
host,
portInt,
dbName,
authDatabase,
)
-ctx := context.Background()
+ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+defer cancel()
clientOptions := options.Client().ApplyURI(uri)
client, err := mongo.Connect(ctx, clientOptions)
if err != nil {
@@ -434,15 +446,17 @@ func connectToMongodbContainer(
}
func createMongodbModel(container *MongodbContainer) *MongodbDatabase {
port := container.Port
return &MongodbDatabase{
Version: container.Version,
Host: container.Host,
-Port: container.Port,
+Port: &port,
Username: container.Username,
Password: container.Password,
Database: container.Database,
AuthDatabase: container.AuthDatabase,
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}
@@ -481,3 +495,240 @@ func assertWriteDenied(t *testing.T, err error) {
strings.Contains(errStr, "permission denied"),
"Expected authorization error, got: %v", err)
}
func Test_BuildConnectionURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
}
func Test_BuildConnectionURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/mydb")
assert.Contains(t, uri, "authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
}
func Test_BuildConnectionURI_WithNullPort_UsesDefault(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "localhost:27017")
}
func Test_BuildMongodumpURI_WithSrvFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb+srv://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "cluster0.example.mongodb.net")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, ":27017")
assert.NotContains(t, uri, "/mydb")
}
func Test_BuildMongodumpURI_WithStandardFormat_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "testuser")
assert.Contains(t, uri, "testpass123")
assert.Contains(t, uri, "localhost:27017")
assert.Contains(t, uri, "/?authSource=admin")
assert.Contains(t, uri, "connectTimeoutMS=15000")
assert.NotContains(t, uri, "mongodb+srv://")
assert.NotContains(t, uri, "/mydb")
}
func Test_Validate_SrvConnection_AllowsNullPort(t *testing.T) {
model := &MongodbDatabase{
Host: "cluster0.example.mongodb.net",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: true,
CpuCount: 1,
}
err := model.Validate()
assert.NoError(t, err)
}
func Test_BuildConnectionURI_WithDirectConnection_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "directConnection=true")
assert.Contains(t, uri, "mongo.example.local:27017")
assert.Contains(t, uri, "authSource=admin")
}
func Test_BuildConnectionURI_WithoutDirectConnection_OmitsParam(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "localhost",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: false,
}
uri := model.buildConnectionURI("testpass123")
assert.NotContains(t, uri, "directConnection")
}
func Test_BuildMongodumpURI_WithDirectConnection_ReturnsCorrectUri(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.BuildMongodumpURI("testpass123")
assert.Contains(t, uri, "mongodb://")
assert.Contains(t, uri, "directConnection=true")
assert.NotContains(t, uri, "/mydb")
}
func Test_BuildConnectionURI_WithDirectConnectionAndTls_ReturnsBothParams(t *testing.T) {
port := 27017
model := &MongodbDatabase{
Host: "mongo.example.local",
Port: &port,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: true,
IsSrv: false,
IsDirectConnection: true,
}
uri := model.buildConnectionURI("testpass123")
assert.Contains(t, uri, "directConnection=true")
assert.Contains(t, uri, "tls=true")
assert.Contains(t, uri, "tlsInsecure=true")
}
func Test_Validate_StandardConnection_RequiresPort(t *testing.T) {
model := &MongodbDatabase{
Host: "localhost",
Port: nil,
Username: "testuser",
Password: "testpass123",
Database: "mydb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
err := model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "port is required for standard connections")
}

View File

@@ -400,6 +400,7 @@ func HasPrivilege(privileges, priv string) bool {
func (m *MysqlDatabase) buildDSN(password string, database string) string {
tlsConfig := "false"
allowCleartext := ""
if m.IsHttps {
err := mysql.RegisterTLSConfig("mysql-skip-verify", &tls.Config{
@@ -411,16 +412,18 @@ func (m *MysqlDatabase) buildDSN(password string, database string) string {
}
tlsConfig = "mysql-skip-verify"
allowCleartext = "&allowCleartextPasswords=1"
}
return fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4%s",
m.Username,
password,
m.Host,
m.Port,
database,
tlsConfig,
allowCleartext,
)
}
@@ -486,9 +489,13 @@ func detectPrivileges(ctx context.Context, db *sql.DB, database string) (string,
hasProcess := false
hasAllPrivileges := false
// Escape underscores to match MySQL's grant output format
// MySQL escapes _ as \_ in SHOW GRANTS output
// Pattern matches either literal _ or escaped \_
escapedDbName := strings.ReplaceAll(regexp.QuoteMeta(database), "_", `(_|\\_)`)
dbPatternStr := fmt.Sprintf(
`(?i)ON\s+[\x60'"]?%s[\x60'"]?\s*\.\s*\*`,
-regexp.QuoteMeta(database),
+escapedDbName,
)
dbPattern := regexp.MustCompile(dbPatternStr)
globalPattern := regexp.MustCompile(`(?i)ON\s+\*\s*\.\s*\*`)

View File

@@ -674,6 +674,112 @@ func Test_TestConnection_DatabaseWithUnderscores_Success(t *testing.T) {
assert.NoError(t, err)
}
func Test_TestConnection_DatabaseWithUnderscoresAndAllPrivileges_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version tools.MysqlVersion
port string
}{
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
{"MySQL 9", tools.MysqlVersion9, env.TestMysql90Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToMysqlContainer(t, tc.port, tc.version)
defer container.DB.Close()
underscoreDbName := "test_all_db"
_, err := container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", underscoreDbName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", underscoreDbName),
)
}()
underscoreDSN := fmt.Sprintf(
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
container.Username,
container.Password,
container.Host,
container.Port,
underscoreDbName,
)
underscoreDB, err := sqlx.Connect("mysql", underscoreDSN)
assert.NoError(t, err)
defer underscoreDB.Close()
_, err = underscoreDB.Exec(`
CREATE TABLE all_priv_test (
id INT AUTO_INCREMENT PRIMARY KEY,
data VARCHAR(255) NOT NULL
)
`)
assert.NoError(t, err)
_, err = underscoreDB.Exec(`INSERT INTO all_priv_test (data) VALUES ('test1')`)
assert.NoError(t, err)
allPrivUsername := fmt.Sprintf("allpriv_%s", uuid.New().String()[:8])
allPrivPassword := "allprivpass123"
_, err = underscoreDB.Exec(fmt.Sprintf(
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
allPrivUsername,
allPrivPassword,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec(fmt.Sprintf(
"GRANT ALL PRIVILEGES ON `%s`.* TO '%s'@'%%'",
underscoreDbName,
allPrivUsername,
))
assert.NoError(t, err)
_, err = underscoreDB.Exec("FLUSH PRIVILEGES")
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", allPrivUsername),
)
}()
mysqlModel := &MysqlDatabase{
Version: tc.version,
Host: container.Host,
Port: container.Port,
Username: allPrivUsername,
Password: allPrivPassword,
Database: &underscoreDbName,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
err = mysqlModel.TestConnection(logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, mysqlModel.Privileges)
assert.Contains(t, mysqlModel.Privileges, "SELECT")
assert.Contains(t, mysqlModel.Privileges, "SHOW VIEW")
})
}
}
type MysqlContainer struct {
Host string
Port int
@@ -694,7 +800,7 @@ func connectToMysqlContainer(
}
dbName := "testdb"
host := "127.0.0.1"
host := config.GetEnv().TestLocalhost
username := "root"
password := "rootpassword"

View File

@@ -13,6 +13,7 @@ import (
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"gorm.io/gorm"
)
@@ -85,6 +86,42 @@ func (p *PostgresqlDatabase) Validate() error {
return errors.New("cpu count must be greater than 0")
}
// Prevent Databasus from backing up itself
// Databasus runs an internal PostgreSQL instance that should not be backed up through the UI
// because it would expose internal metadata to non-system administrators.
// To properly back up Databasus, see: https://databasus.com/faq#backup-databasus
if p.Database != nil && *p.Database != "" {
localhostHosts := []string{
"localhost",
"127.0.0.1",
"172.17.0.1",
"host.docker.internal",
"::1", // IPv6 loopback (equivalent to 127.0.0.1)
"::", // IPv6 all interfaces (equivalent to 0.0.0.0)
"0.0.0.0", // IPv4 all interfaces
}
isLocalhost := false
for _, host := range localhostHosts {
if strings.EqualFold(p.Host, host) {
isLocalhost = true
break
}
}
// Also check if the host is in the entire 127.0.0.0/8 loopback range
if strings.HasPrefix(p.Host, "127.") {
isLocalhost = true
}
if isLocalhost && strings.EqualFold(*p.Database, "databasus") {
return errors.New(
"backing up Databasus internal database is not allowed. To backup Databasus itself, see https://databasus.com/faq#backup-databasus",
)
}
}
return nil
}
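
A minimal standalone version of this guard, for illustration (the helper name isLocalAddress is ours, not the repository's):

package main

import (
	"fmt"
	"strings"
)

func isLocalAddress(host string) bool {
	for _, h := range []string{
		"localhost", "127.0.0.1", "172.17.0.1", "host.docker.internal",
		"::1", "::", "0.0.0.0",
	} {
		if strings.EqualFold(host, h) {
			return true
		}
	}
	// Any 127.x.x.x address is loopback (127.0.0.0/8).
	return strings.HasPrefix(host, "127.")
}

func main() {
	fmt.Println(isLocalAddress("127.0.0.2"))      // true
	fmt.Println(isLocalAddress("LOCALHOST"))      // true (case-insensitive)
	fmt.Println(isLocalAddress("db.example.com")) // false
}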
@@ -358,10 +395,13 @@ func (p *PostgresqlDatabase) IsUserReadOnly(
//
// This method performs the following operations atomically in a single transaction:
// 1. Creates a PostgreSQL user with a UUID-based password
-// 2. Grants CONNECT privilege on the database
-// 3. Grants USAGE on all non-system schemas
-// 4. Grants SELECT on all existing tables and sequences
-// 5. Sets default privileges for future tables and sequences
+// 2. Revokes CREATE privilege on public schema from PUBLIC role
+// 3. Grants CONNECT privilege on the database
+// 4. Discovers all user-created schemas
+// 5. Grants USAGE on all non-system schemas
+// 6. Grants SELECT on all existing tables and sequences
+// 7. Sets default privileges for future tables and sequences
+// 8. Verifies user creation before committing
//
// Security features:
// - Username format: "databasus-{8-char-uuid}" for uniqueness
@@ -451,33 +491,56 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("failed to create user: %w", err)
}
-// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
+// Step 2: Check if public schema exists and revoke CREATE privilege if it does
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
// public schema through the PUBLIC role. This is a one-time operation that affects
// the entire database, making it more secure by default.
// Note: This only affects the public schema; other schemas are unaffected.
-_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
-if err != nil {
-logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
-if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
-!strings.Contains(err.Error(), "permission denied") {
-return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
-}
-}
-// Now revoke from the specific user as well (belt and suspenders)
-_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
-if err != nil {
-logger.Error(
-"Failed to revoke CREATE on public schema from user",
-"error",
-err,
-"username",
-baseUsername,
-)
-}
-// Step 2: Grant database connection privilege and revoke TEMP
+var publicSchemaExists bool
+err = tx.QueryRow(ctx, `
+SELECT EXISTS(
+SELECT 1 FROM information_schema.schemata
+WHERE schema_name = 'public'
+)
+`).Scan(&publicSchemaExists)
+if err != nil {
+return "", "", fmt.Errorf("failed to check if public schema exists: %w", err)
+}
+if publicSchemaExists {
+// Revoke CREATE from PUBLIC role (affects all users)
+_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
+if err != nil {
+if strings.Contains(err.Error(), "permission denied") {
+logger.Warn(
+"Failed to revoke CREATE on public from PUBLIC (permission denied)",
+"error",
+err,
+)
+} else {
+return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC on existing public schema: %w", err)
+}
+}
+// Now revoke from the specific user as well (belt and suspenders)
+_, err = tx.Exec(
+ctx,
+fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername),
+)
+if err != nil {
+logger.Warn(
+"Failed to revoke CREATE on public schema from user",
+"error",
+err,
+"username",
+baseUsername,
+)
+}
+} else {
+logger.Info("Public schema does not exist, skipping CREATE privilege revocation")
+}
+// Step 3: Grant database connection privilege and revoke TEMP
_, err = tx.Exec(
ctx,
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
@@ -501,12 +564,23 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
}
-// Step 3: Discover all user-created schemas
-rows, err := tx.Query(ctx, `
-SELECT schema_name
-FROM information_schema.schemata
-WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
-`)
+// Step 4: Discover schemas to grant privileges on
+// If IncludeSchemas is specified, only use those schemas; otherwise use all non-system schemas
+var rows pgx.Rows
+if len(p.IncludeSchemas) > 0 {
+rows, err = tx.Query(ctx, `
+SELECT schema_name
+FROM information_schema.schemata
+WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
+AND schema_name = ANY($1::text[])
+`, p.IncludeSchemas)
+} else {
+rows, err = tx.Query(ctx, `
+SELECT schema_name
+FROM information_schema.schemata
+WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
+`)
+}
if err != nil {
return "", "", fmt.Errorf("failed to get schemas: %w", err)
}
@@ -526,7 +600,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
return "", "", fmt.Errorf("error iterating schemas: %w", err)
}
-// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
+// Step 5: Grant USAGE on each schema and explicitly prevent CREATE
for _, schema := range schemas {
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
_, err = tx.Exec(
@@ -555,51 +629,198 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
}
}
-// Step 5: Grant SELECT on ALL existing tables and sequences
-grantSelectSQL := fmt.Sprintf(`
-DO $$
-DECLARE
-schema_rec RECORD;
-BEGIN
-FOR schema_rec IN
-SELECT schema_name
-FROM information_schema.schemata
-WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
-LOOP
-EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
-EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
-END LOOP;
-END $$;
-`, baseUsername, baseUsername)
-_, err = tx.Exec(ctx, grantSelectSQL)
-if err != nil {
-return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
-}
+// Step 6: Grant SELECT on ALL existing tables and sequences
+// Use the already-filtered schemas list from Step 4
+for _, schema := range schemas {
+_, err = tx.Exec(
+ctx,
+fmt.Sprintf(
+`GRANT SELECT ON ALL TABLES IN SCHEMA "%s" TO "%s"`,
+schema,
+baseUsername,
+),
+)
+if err != nil {
+return "", "", fmt.Errorf(
+"failed to grant select on tables in schema %s: %w",
+schema,
+err,
+)
+}
+_, err = tx.Exec(
+ctx,
+fmt.Sprintf(
+`GRANT SELECT ON ALL SEQUENCES IN SCHEMA "%s" TO "%s"`,
+schema,
+baseUsername,
+),
+)
+if err != nil {
+return "", "", fmt.Errorf(
+"failed to grant select on sequences in schema %s: %w",
+schema,
+err,
+)
+}
+}
-// Step 6: Set default privileges for FUTURE tables and sequences
-defaultPrivilegesSQL := fmt.Sprintf(`
-DO $$
-DECLARE
-schema_rec RECORD;
-BEGIN
-FOR schema_rec IN
-SELECT schema_name
-FROM information_schema.schemata
-WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
-LOOP
-EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
-EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
-END LOOP;
-END $$;
-`, baseUsername, baseUsername)
-_, err = tx.Exec(ctx, defaultPrivilegesSQL)
-if err != nil {
-return "", "", fmt.Errorf("failed to set default privileges: %w", err)
-}
+// Step 7: Set default privileges for FUTURE tables and sequences
+// First, set default privileges for objects created by the current user
+// Use the already-filtered schemas list from Step 4
+for _, schema := range schemas {
+_, err = tx.Exec(
+ctx,
+fmt.Sprintf(
+`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
+schema,
+baseUsername,
+),
+)
+if err != nil {
+return "", "", fmt.Errorf(
+"failed to set default privileges for tables in schema %s: %w",
+schema,
+err,
+)
+}
+_, err = tx.Exec(
+ctx,
+fmt.Sprintf(
+`ALTER DEFAULT PRIVILEGES IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
+schema,
+baseUsername,
+),
+)
+if err != nil {
+return "", "", fmt.Errorf(
+"failed to set default privileges for sequences in schema %s: %w",
+schema,
+err,
+)
+}
+}
-// Step 7: Verify user creation before committing
+// Step 8: Discover all roles that own objects in each schema
// This is needed because ALTER DEFAULT PRIVILEGES only applies to objects created by the current role.
// To handle tables created by OTHER users (like the GitHub issue with partitioned tables),
// we need to set "ALTER DEFAULT PRIVILEGES FOR ROLE <owner>" for each object owner.
// Filter by IncludeSchemas if specified.
type SchemaOwner struct {
SchemaName string
RoleName string
}
var ownerRows pgx.Rows
if len(p.IncludeSchemas) > 0 {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND n.nspname = ANY($1::text[])
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`, p.IncludeSchemas)
} else {
ownerRows, err = tx.Query(ctx, `
SELECT DISTINCT n.nspname as schema_name, pg_get_userbyid(c.relowner) as role_name
FROM pg_class c
JOIN pg_namespace n ON c.relnamespace = n.oid
WHERE n.nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
AND c.relkind IN ('r', 'p', 'v', 'm', 'f')
AND pg_get_userbyid(c.relowner) != current_user
ORDER BY n.nspname, role_name
`)
}
if err != nil {
// Log warning but continue - this is a best-effort enhancement
logger.Warn("Failed to query object owners for default privileges", "error", err)
} else {
var schemaOwners []SchemaOwner
for ownerRows.Next() {
var so SchemaOwner
if err := ownerRows.Scan(&so.SchemaName, &so.RoleName); err != nil {
ownerRows.Close()
logger.Warn("Failed to scan schema owner", "error", err)
break
}
schemaOwners = append(schemaOwners, so)
}
ownerRows.Close()
if err := ownerRows.Err(); err != nil {
logger.Warn("Error iterating schema owners", "error", err)
}
// Step 9: Set default privileges FOR ROLE for each object owner
// Note: This may fail for some roles due to permission issues (e.g., roles owned by other superusers)
// We log warnings but continue - user creation should succeed even if some roles can't be configured
for _, so := range schemaOwners {
// Try to set default privileges for tables
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON TABLES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (tables)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
// Try to set default privileges for sequences
_, err = tx.Exec(
ctx,
fmt.Sprintf(
`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA "%s" GRANT SELECT ON SEQUENCES TO "%s"`,
so.RoleName,
so.SchemaName,
baseUsername,
),
)
if err != nil {
logger.Warn(
"Failed to set default privileges for role (sequences)",
"error",
err,
"role",
so.RoleName,
"schema",
so.SchemaName,
"readonly_user",
baseUsername,
)
}
}
if len(schemaOwners) > 0 {
logger.Info(
"Set default privileges for existing object owners",
"readonly_user",
baseUsername,
"owner_count",
len(schemaOwners),
)
}
}
// Step 10: Verify user creation before committing
var verifyUsername string
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
Scan(&verifyUsername)
@@ -815,7 +1036,15 @@ func checkBackupPermissions(
}
if err != nil {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
// If the user doesn't have USAGE on the schema, has_table_privilege will fail
// with "permission denied for schema". This means they definitely don't have
// SELECT privileges, so treat this as missing permissions rather than an error.
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == "42501" { // insufficient_privilege
selectableTableCount = 0
} else {
return fmt.Errorf("cannot check SELECT privileges: %w", err)
}
}
if selectableTableCount == 0 {
missingPrivileges = append(missingPrivileges, "SELECT on tables")
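
The FOR ROLE form used in Step 9 above is the crux of the change: plain ALTER DEFAULT PRIVILEGES only covers objects later created by the role that executes it. A minimal psql illustration with made-up names (schema app, owner app_owner, read-only user databasus-ro):

-- Covers only future objects created by the role running this statement:
ALTER DEFAULT PRIVILEGES IN SCHEMA app GRANT SELECT ON TABLES TO "databasus-ro";

-- Also covers future objects created by an existing owner (app_owner),
-- which is what keeps tables made by other users readable:
ALTER DEFAULT PRIVILEGES FOR ROLE app_owner IN SCHEMA app
GRANT SELECT ON TABLES TO "databasus-ro";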

View File

@@ -599,6 +599,10 @@ func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
}
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
if config.GetEnv().IsSkipExternalResourcesTests {
t.Skip("Skipping Supabase test: IS_SKIP_EXTERNAL_RESOURCES_TESTS is true")
}
env := config.GetEnv()
if env.TestSupabaseHost == "" {
@@ -705,6 +709,607 @@ func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
assert.Contains(t, err.Error(), "permission denied")
}
func Test_CreateReadOnlyUser_WithPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP TABLE IF EXISTS public_schema_test CASCADE;
CREATE TABLE public_schema_test (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public_schema_test (data) VALUES ('test1'), ('test2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err)
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var count int
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM public_schema_test")
assert.NoError(t, err)
assert.Equal(t, 2, count)
_, err = readOnlyConn.Exec(
"INSERT INTO public_schema_test (data) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_WithoutPublicSchema_Success(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS public CASCADE;
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA app_schema;
CREATE SCHEMA data_schema;
CREATE TABLE app_schema.users (
id SERIAL PRIMARY KEY,
username TEXT NOT NULL
);
CREATE TABLE data_schema.records (
id SERIAL PRIMARY KEY,
info TEXT NOT NULL
);
INSERT INTO app_schema.users (username) VALUES ('user1'), ('user2');
INSERT INTO data_schema.records (info) VALUES ('record1'), ('record2');
`)
assert.NoError(t, err)
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.NoError(t, err, "CreateReadOnlyUser should succeed without public schema")
assert.NotEmpty(t, username)
assert.NotEmpty(t, password)
assert.True(t, strings.HasPrefix(username, "databasus-"))
readOnlyModel := &PostgresqlDatabase{
Version: pgModel.Version,
Host: pgModel.Host,
Port: pgModel.Port,
Username: username,
Password: password,
Database: pgModel.Database,
IsHttps: false,
}
isReadOnly, privileges, err := readOnlyModel.IsUserReadOnly(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.True(t, isReadOnly, "User should be read-only")
assert.Empty(t, privileges, "Read-only user should have no write privileges")
readOnlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
username,
password,
container.Database,
)
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
assert.NoError(t, err)
defer readOnlyConn.Close()
var userCount int
err = readOnlyConn.Get(&userCount, "SELECT COUNT(*) FROM app_schema.users")
assert.NoError(t, err)
assert.Equal(t, 2, userCount)
var recordCount int
err = readOnlyConn.Get(&recordCount, "SELECT COUNT(*) FROM data_schema.records")
assert.NoError(t, err)
assert.Equal(t, 2, recordCount)
_, err = readOnlyConn.Exec(
"INSERT INTO app_schema.users (username) VALUES ('should-fail')",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE app_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = readOnlyConn.Exec("CREATE TABLE data_schema.hack_table (id INT)")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
if err != nil {
t.Logf("Warning: Failed to drop owned objects: %v", err)
}
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
assert.NoError(t, err)
_, err = container.DB.Exec(`
DROP SCHEMA IF EXISTS app_schema CASCADE;
DROP SCHEMA IF EXISTS data_schema CASCADE;
CREATE SCHEMA IF NOT EXISTS public;
`)
assert.NoError(t, err)
})
}
}
func Test_CreateReadOnlyUser_PublicSchemaExistsButNoPermissions_ReturnsError(t *testing.T) {
env := config.GetEnv()
cases := []struct {
name string
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
{"PostgreSQL 16", "16", env.TestPostgres16Port},
{"PostgreSQL 17", "17", env.TestPostgres17Port},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
container := connectToPostgresContainer(t, tc.port)
defer container.DB.Close()
limitedAdminUsername := fmt.Sprintf("limited_admin_%s", uuid.New().String()[:8])
limitedAdminPassword := "limited_password_123"
_, err := container.DB.Exec(`
CREATE SCHEMA IF NOT EXISTS public;
DROP TABLE IF EXISTS public.permission_test_table CASCADE;
CREATE TABLE public.permission_test_table (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO public.permission_test_table (data) VALUES ('test1');
`)
assert.NoError(t, err)
_, err = container.DB.Exec(`GRANT CREATE ON SCHEMA public TO PUBLIC`)
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN CREATEROLE`,
limitedAdminUsername,
limitedAdminPassword,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
limitedAdminUsername,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, limitedAdminUsername),
)
_, _ = container.DB.Exec(
fmt.Sprintf(`DROP USER IF EXISTS "%s"`, limitedAdminUsername),
)
_, _ = container.DB.Exec(`REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
}()
limitedAdminDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
limitedAdminUsername,
limitedAdminPassword,
container.Database,
)
limitedAdminConn, err := sqlx.Connect("postgres", limitedAdminDSN)
assert.NoError(t, err)
defer limitedAdminConn.Close()
pgModel := &PostgresqlDatabase{
Version: tools.GetPostgresqlVersionEnum(tc.version),
Host: container.Host,
Port: container.Port,
Username: limitedAdminUsername,
Password: limitedAdminPassword,
Database: &container.Database,
IsHttps: false,
}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
assert.Error(
t,
err,
"CreateReadOnlyUser should fail when admin lacks permissions to secure public schema",
)
if err != nil {
errorMsg := err.Error()
hasExpectedError := strings.Contains(
errorMsg,
"failed to revoke CREATE from PUBLIC on existing public schema",
) ||
strings.Contains(errorMsg, "permission denied for schema public") ||
strings.Contains(errorMsg, "failed to grant")
assert.True(
t,
hasExpectedError,
"Error should indicate permission issues with public schema, got: %s",
errorMsg,
)
}
assert.Empty(t, username)
assert.Empty(t, password)
})
}
}
func Test_Validate_WhenLocalhostAndDatabasus_ReturnsError(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "localhost with databasus db",
host: "localhost",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.1 with databasus db",
host: "127.0.0.1",
username: "postgres",
database: "databasus",
},
{
name: "172.17.0.1 with databasus db",
host: "172.17.0.1",
username: "postgres",
database: "databasus",
},
{
name: "host.docker.internal with databasus db",
host: "host.docker.internal",
username: "postgres",
database: "databasus",
},
{
name: "LOCALHOST (uppercase) with DATABASUS db",
host: "LOCALHOST",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "LocalHost (mixed case) with DataBasus db",
host: "LocalHost",
username: "anyuser",
database: "DataBasus",
},
{
name: "localhost with databasus and any username",
host: "localhost",
username: "myuser",
database: "databasus",
},
{
name: "::1 (IPv6 loopback) with databasus db",
host: "::1",
username: "postgres",
database: "databasus",
},
{
name: ":: (IPv6 all interfaces) with databasus db",
host: "::",
username: "postgres",
database: "databasus",
},
{
name: "::1 (uppercase) with DATABASUS db",
host: "::1",
username: "POSTGRES",
database: "DATABASUS",
},
{
name: "0.0.0.0 (all IPv4 interfaces) with databasus db",
host: "0.0.0.0",
username: "postgres",
database: "databasus",
},
{
name: "127.0.0.2 (loopback range) with databasus db",
host: "127.0.0.2",
username: "postgres",
database: "databasus",
},
{
name: "127.255.255.255 (end of loopback range) with databasus db",
host: "127.255.255.255",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5437,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "backing up Databasus internal database is not allowed")
assert.Contains(t, err.Error(), "https://databasus.com/faq#backup-databasus")
})
}
}
func Test_Validate_WhenNotLocalhostOrNotDatabasus_ValidatesSuccessfully(t *testing.T) {
testCases := []struct {
name string
host string
username string
database string
}{
{
name: "different host (remote server) with databasus db",
host: "192.168.1.100",
username: "postgres",
database: "databasus",
},
{
name: "different database name on localhost",
host: "localhost",
username: "postgres",
database: "myapp",
},
{
name: "all different",
host: "db.example.com",
username: "appuser",
database: "production",
},
{
name: "localhost with postgres database",
host: "localhost",
username: "postgres",
database: "postgres",
},
{
name: "remote host with databasus db name (allowed)",
host: "db.example.com",
username: "postgres",
database: "databasus",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: tc.host,
Port: 5432,
Username: tc.username,
Password: "somepassword",
Database: &tc.database,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
})
}
}
func Test_Validate_WhenDatabaseIsNil_ValidatesSuccessfully(t *testing.T) {
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: nil,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenDatabaseIsEmpty_ValidatesSuccessfully(t *testing.T) {
emptyDb := ""
pgModel := &PostgresqlDatabase{
Host: "localhost",
Port: 5437,
Username: "postgres",
Password: "somepassword",
Database: &emptyDb,
CpuCount: 1,
}
err := pgModel.Validate()
assert.NoError(t, err)
}
func Test_Validate_WhenRequiredFieldsMissing_ReturnsError(t *testing.T) {
testCases := []struct {
name string
model *PostgresqlDatabase
expectedError string
}{
{
name: "missing host",
model: &PostgresqlDatabase{
Host: "",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "host is required",
},
{
name: "missing port",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 0,
Username: "user",
Password: "pass",
CpuCount: 1,
},
expectedError: "port is required",
},
{
name: "missing username",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "",
Password: "pass",
CpuCount: 1,
},
expectedError: "username is required",
},
{
name: "missing password",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "",
CpuCount: 1,
},
expectedError: "password is required",
},
{
name: "invalid cpu count",
model: &PostgresqlDatabase{
Host: "localhost",
Port: 5432,
Username: "user",
Password: "pass",
CpuCount: 0,
},
expectedError: "cpu count must be greater than 0",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := tc.model.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedError)
})
}
}
type PostgresContainer struct {
Host string
Port int
@@ -714,11 +1319,351 @@ type PostgresContainer struct {
DB *sqlx.DB
}
func Test_CreateReadOnlyUser_TablesCreatedByDifferentUser_ReadOnlyUserCanRead(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create a second database user who will create tables
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err := container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Step 2: Grant the user_creator privileges to connect and create tables
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CREATE ON SCHEMA public TO "%s"`,
userCreatorUsername,
))
assert.NoError(t, err)
// Step 2b: Create an initial table by user_creator so they become an object owner
// This is important because our fix discovers existing object owners
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
initialTableName := fmt.Sprintf(
"public.initial_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('initial_data');
`, initialTableName, initialTableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, initialTableName))
}()
// Step 3: NOW create read-only user via Databasus (as admin)
// At this point, user_creator already owns objects, so ALTER DEFAULT PRIVILEGES FOR ROLE should apply
pgModel := createPostgresModel(container)
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: user_creator creates a NEW table AFTER the read-only user was created
// SELECT on this table should be granted to the read-only user automatically via ALTER DEFAULT PRIVILEGES FOR ROLE
tableName := fmt.Sprintf(
"public.future_table_%s",
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
)
_, err = userCreatorConn.Exec(fmt.Sprintf(`
CREATE TABLE %s (
id SERIAL PRIMARY KEY,
data TEXT NOT NULL
);
INSERT INTO %s (data) VALUES ('test_data_1'), ('test_data_2');
`, tableName, tableName))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS %s CASCADE`, tableName))
}()
// Step 5: Connect as read-only user and verify it can SELECT from the new table
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
var count int
err = readonlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName))
assert.NoError(t, err)
assert.Equal(
t,
2,
count,
"Read-only user should be able to SELECT from table created by different user",
)
// Step 6: Verify read-only user cannot write to the table
_, err = readonlyConn.Exec(
fmt.Sprintf("INSERT INTO %s (data) VALUES ('should-fail')", tableName),
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify pg_dump operations (LOCK TABLE) work
// pg_dump needs to lock tables in ACCESS SHARE MODE for consistent backup
tx, err := readonlyConn.Begin()
assert.NoError(t, err)
defer tx.Rollback()
_, err = tx.Exec(fmt.Sprintf("LOCK TABLE %s IN ACCESS SHARE MODE", tableName))
assert.NoError(t, err, "Read-only user should be able to LOCK TABLE (needed for pg_dump)")
err = tx.Commit()
assert.NoError(t, err)
}
func Test_CreateReadOnlyUser_WithIncludeSchemas_OnlyGrantsAccessToSpecifiedSchemas(t *testing.T) {
env := config.GetEnv()
container := connectToPostgresContainer(t, env.TestPostgres16Port)
defer container.DB.Close()
// Step 1: Create multiple schemas and tables
_, err := container.DB.Exec(`
DROP SCHEMA IF EXISTS included_schema CASCADE;
DROP SCHEMA IF EXISTS excluded_schema CASCADE;
CREATE SCHEMA included_schema;
CREATE SCHEMA excluded_schema;
CREATE TABLE public.public_table (id INT, data TEXT);
INSERT INTO public.public_table VALUES (1, 'public_data');
CREATE TABLE included_schema.included_table (id INT, data TEXT);
INSERT INTO included_schema.included_table VALUES (2, 'included_data');
CREATE TABLE excluded_schema.excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.excluded_table VALUES (3, 'excluded_data');
`)
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS included_schema CASCADE`)
_, _ = container.DB.Exec(`DROP SCHEMA IF EXISTS excluded_schema CASCADE`)
}()
// Step 2: Create a second user who owns tables in both included and excluded schemas
userCreatorUsername := fmt.Sprintf("user_creator_%s", uuid.New().String()[:8])
userCreatorPassword := "creator_password_123"
_, err = container.DB.Exec(fmt.Sprintf(
`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`,
userCreatorUsername,
userCreatorPassword,
))
assert.NoError(t, err)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, userCreatorUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, userCreatorUsername))
}()
// Grant privileges to user_creator
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT CONNECT ON DATABASE "%s" TO "%s"`,
container.Database,
userCreatorUsername,
))
assert.NoError(t, err)
for _, schema := range []string{"public", "included_schema", "excluded_schema"} {
_, err = container.DB.Exec(fmt.Sprintf(
`GRANT USAGE, CREATE ON SCHEMA %s TO "%s"`,
schema,
userCreatorUsername,
))
assert.NoError(t, err)
}
// User_creator creates tables in included and excluded schemas
userCreatorDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
userCreatorUsername,
userCreatorPassword,
container.Database,
)
userCreatorConn, err := sqlx.Connect("postgres", userCreatorDSN)
assert.NoError(t, err)
defer userCreatorConn.Close()
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.user_table (id INT, data TEXT);
INSERT INTO included_schema.user_table VALUES (4, 'user_included_data');
CREATE TABLE excluded_schema.user_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.user_excluded_table VALUES (5, 'user_excluded_data');
`)
assert.NoError(t, err)
// Step 3: Create read-only user with IncludeSchemas = ["public", "included_schema"]
pgModel := createPostgresModel(container)
pgModel.IncludeSchemas = []string{"public", "included_schema"}
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
ctx := context.Background()
readonlyUsername, readonlyPassword, err := pgModel.CreateReadOnlyUser(
ctx,
logger,
nil,
uuid.New(),
)
assert.NoError(t, err)
assert.NotEmpty(t, readonlyUsername)
assert.NotEmpty(t, readonlyPassword)
defer func() {
_, _ = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, readonlyUsername))
_, _ = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, readonlyUsername))
}()
// Step 4: Connect as read-only user
readonlyDSN := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
container.Host,
container.Port,
readonlyUsername,
readonlyPassword,
container.Database,
)
readonlyConn, err := sqlx.Connect("postgres", readonlyDSN)
assert.NoError(t, err)
defer readonlyConn.Close()
// Step 5: Verify read-only user CAN access included schemas
var publicData string
err = readonlyConn.Get(&publicData, "SELECT data FROM public.public_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "public_data", publicData)
var includedData string
err = readonlyConn.Get(&includedData, "SELECT data FROM included_schema.included_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "included_data", includedData)
var userIncludedData string
err = readonlyConn.Get(&userIncludedData, "SELECT data FROM included_schema.user_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(t, "user_included_data", userIncludedData)
// Step 6: Verify read-only user CANNOT access excluded schema
var excludedData string
err = readonlyConn.Get(&excludedData, "SELECT data FROM excluded_schema.excluded_table LIMIT 1")
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
err = readonlyConn.Get(
&excludedData,
"SELECT data FROM excluded_schema.user_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(t, err.Error(), "permission denied")
// Step 7: Verify future tables in included schemas are accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE included_schema.future_table (id INT, data TEXT);
INSERT INTO included_schema.future_table VALUES (6, 'future_data');
`)
assert.NoError(t, err)
var futureData string
err = readonlyConn.Get(&futureData, "SELECT data FROM included_schema.future_table LIMIT 1")
assert.NoError(t, err)
assert.Equal(
t,
"future_data",
futureData,
"Read-only user should access future tables in included schemas via ALTER DEFAULT PRIVILEGES FOR ROLE",
)
// Step 8: Verify future tables in excluded schema are NOT accessible
_, err = userCreatorConn.Exec(`
CREATE TABLE excluded_schema.future_excluded_table (id INT, data TEXT);
INSERT INTO excluded_schema.future_excluded_table VALUES (7, 'future_excluded_data');
`)
assert.NoError(t, err)
var futureExcludedData string
err = readonlyConn.Get(
&futureExcludedData,
"SELECT data FROM excluded_schema.future_excluded_table LIMIT 1",
)
assert.Error(t, err)
assert.Contains(
t,
err.Error(),
"permission denied",
"Read-only user should NOT access tables in excluded schemas",
)
}
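The assertions above hinge on CreateReadOnlyUser granting access per schema, including default privileges scoped to the creating role so that future tables stay covered. The exact SQL is not shown in this diff; the following is a plausible Go sketch of the per-schema grants the test implies, with placeholder role and schema names.
// Hypothetical sketch only, not the actual Databasus implementation.
// Grants read-only access on one schema, including tables that
// creatorRole creates later (what the future_table assertions verify).
func grantReadOnlyOnSchema(db *sqlx.DB, schema, readonlyRole, creatorRole string) error {
	stmts := []string{
		fmt.Sprintf(`GRANT USAGE ON SCHEMA %s TO "%s"`, schema, readonlyRole),
		fmt.Sprintf(`GRANT SELECT ON ALL TABLES IN SCHEMA %s TO "%s"`, schema, readonlyRole),
		fmt.Sprintf(
			`ALTER DEFAULT PRIVILEGES FOR ROLE "%s" IN SCHEMA %s GRANT SELECT ON TABLES TO "%s"`,
			creatorRole, schema, readonlyRole,
		),
	}
	for _, stmt := range stmts {
		if _, err := db.Exec(stmt); err != nil {
			return err
		}
	}
	return nil
}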
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
dbName := "testdb"
password := "testpassword"
username := "testuser"
host := "localhost"
host := config.GetEnv().TestLocalhost
portInt, err := strconv.Atoi(port)
assert.NoError(t, err)

View File

@@ -1,6 +1,9 @@
package databases
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/notifiers"
users_services "databasus-backend/internal/features/users/services"
@@ -37,7 +40,22 @@ func GetDatabaseController() *DatabaseController {
return databaseController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
notifiers.GetNotifierService().SetNotifierDatabaseCounter(databaseService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}
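The load-then-Do shape turns repeated wiring into a logged no-op instead of a double listener registration. Illustrative usage:
// Illustrative only: the second call does not register listeners again.
databases.SetupDependencies() // wires the workspace-deletion listener and notifier counter
databases.SetupDependencies() // no-op; logs "SetupDependencies called multiple times, ignoring subsequent call"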

View File

@@ -1,6 +1,7 @@
package databases
import (
"context"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
"databasus-backend/internal/features/databases/databases/mysql"
@@ -84,6 +85,25 @@ func (d *Database) TestConnection(
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) IsUserReadOnly(
ctx context.Context,
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) (bool, []string, error) {
switch d.Type {
case DatabaseTypePostgres:
return d.Postgresql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMysql:
return d.Mysql.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMariadb:
return d.Mariadb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
case DatabaseTypeMongodb:
return d.Mongodb.IsUserReadOnly(ctx, logger, encryptor, d.ID)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"time"
"databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases/databases/mariadb"
"databasus-backend/internal/features/databases/databases/mongodb"
@@ -86,6 +87,23 @@ func (s *DatabaseService) CreateDatabase(
return nil, fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := database.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
if err != nil {
return nil, fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return nil, fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -153,6 +171,29 @@ func (s *DatabaseService) UpdateDatabase(
return fmt.Errorf("failed to auto-detect database data: %w", err)
}
if config.GetEnv().IsCloud {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
isReadOnly, permissions, err := existingDatabase.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
)
if err != nil {
return fmt.Errorf("failed to verify user permissions: %w", err)
}
if !isReadOnly {
return fmt.Errorf(
"in cloud mode, only read-only database users are allowed (user has permissions: %v)",
permissions,
)
}
}
oldName := existingDatabase.Name
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
@@ -162,11 +203,23 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
if oldName != existingDatabase.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database updated and renamed from '%s' to '%s'",
oldName,
existingDatabase.Name,
),
&user.ID,
existingDatabase.WorkspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
}
return nil
}
@@ -532,9 +585,19 @@ func (s *DatabaseService) TransferDatabaseToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(*sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database transferred: %s from workspace %s to workspace %s",
database.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Database transferred: %s from workspace '%s' to workspace '%s'",
database.Name, sourceWorkspace.Name, targetWorkspace.Name),
nil,
&targetWorkspaceID,
)
@@ -649,38 +712,7 @@ func (s *DatabaseService) IsUserReadOnly(
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
switch usingDatabase.Type {
case DatabaseTypePostgres:
return usingDatabase.Postgresql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMysql:
return usingDatabase.Mysql.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMariadb:
return usingDatabase.Mariadb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
case DatabaseTypeMongodb:
return usingDatabase.Mongodb.IsUserReadOnly(
ctx,
s.logger,
s.fieldEncryptor,
usingDatabase.ID,
)
default:
return false, nil, errors.New("read-only check not supported for this database type")
}
return usingDatabase.IsUserReadOnly(ctx, s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) CreateReadOnlyUser(

View File

@@ -10,6 +10,7 @@ import (
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/notifiers"
"databasus-backend/internal/features/storages"
"databasus-backend/internal/storage"
"databasus-backend/internal/util/tools"
"github.com/google/uuid"
@@ -25,7 +26,7 @@ func GetTestPostgresConfig() *postgresql.PostgresqlDatabase {
testDbName := "testdb"
return &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -48,7 +49,7 @@ func GetTestMariadbConfig() *mariadb.MariadbDatabase {
testDbName := "testdb"
return &mariadb.MariadbDatabase{
Version: tools.MariadbVersion1011,
Host: "localhost",
Host: config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -69,13 +70,14 @@ func GetTestMongodbConfig() *mongodb.MongodbDatabase {
return &mongodb.MongodbDatabase{
Version: tools.MongodbVersion7,
Host: "localhost",
Port: port,
Host: config.GetEnv().TestLocalhost,
Port: &port,
Username: "root",
Password: "rootpassword",
Database: "testdb",
AuthDatabase: "admin",
IsHttps: false,
IsSrv: false,
CpuCount: 1,
}
}
@@ -104,6 +106,19 @@ func CreateTestDatabase(
}
func RemoveTestDatabase(database *Database) {
// Delete backups and backup configs associated with this database
// We hardcode SQL here because we cannot call the backups feature directly:
// the databases package cannot import the backups package, as backups already imports databases
db := storage.GetDb()
if err := db.Exec("DELETE FROM backups WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backups: %v", err))
}
if err := db.Exec("DELETE FROM backup_configs WHERE database_id = ?", database.ID).Error; err != nil {
panic(fmt.Sprintf("failed to delete backup config: %v", err))
}
err := databaseRepository.Delete(database.ID)
if err != nil {
panic(err)

View File

@@ -12,6 +12,15 @@ import (
type DiskService struct{}
func (s *DiskService) GetDiskUsage() (*DiskUsage, error) {
if config.GetEnv().IsCloud {
return &DiskUsage{
Platform: PlatformLinux,
TotalSpaceBytes: 100,
UsedSpaceBytes: 0,
FreeSpaceBytes: 100,
}, nil
}
platform := s.detectPlatform()
var path string

View File

@@ -0,0 +1,22 @@
package email
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/logger"
)
var env = config.GetEnv()
var log = logger.GetLogger()
var emailSMTPSender = &EmailSMTPSender{
log,
env.SMTPHost,
env.SMTPPort,
env.SMTPUser,
env.SMTPPassword,
env.SMTPHost != "" && env.SMTPPort != 0,
}
func GetEmailSMTPSender() *EmailSMTPSender {
return emailSMTPSender
}

View File

@@ -0,0 +1,245 @@
package email
import (
"crypto/tls"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"time"
)
const (
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
DefaultHelloName = "localhost"
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
)
type EmailSMTPSender struct {
logger *slog.Logger
smtpHost string
smtpPort int
smtpUser string
smtpPassword string
isConfigured bool
}
func (s *EmailSMTPSender) SendEmail(to, subject, body string) error {
if !s.isConfigured {
s.logger.Warn("Skipping email send, SMTP not initialized", "to", to, "subject", subject)
return nil
}
from := s.smtpUser
if from == "" {
from = "noreply@" + s.smtpHost
}
emailContent := s.buildEmailContent(to, subject, body, from)
isAuthRequired := s.smtpUser != "" && s.smtpPassword != ""
if s.smtpPort == ImplicitTLSPort {
return s.sendImplicitTLS(to, from, emailContent, isAuthRequired)
}
return s.sendStartTLS(to, from, emailContent, isAuthRequired)
}
func (s *EmailSMTPSender) buildEmailContent(to, subject, body, from string) []byte {
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
encodedSubject := encodeRFC2047(subject)
subjectHeader := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", to)
return []byte(fromHeader + toHeader + subjectHeader + dateHeader + mimeHeaders + body)
}
func (s *EmailSMTPSender) sendImplicitTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createImplicitTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) sendStartTLS(
to, from string,
emailContent []byte,
isAuthRequired bool,
) error {
createClient := func() (*smtp.Client, func(), error) {
return s.createStartTLSClient()
}
client, cleanup, err := s.authenticateWithRetry(createClient, isAuthRequired)
if err != nil {
return err
}
defer cleanup()
return s.sendEmail(client, to, from, emailContent)
}
func (s *EmailSMTPSender) createImplicitTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
tlsConfig := &tls.Config{ServerName: s.smtpHost}
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) createStartTLSClient() (*smtp.Client, func(), error) {
addr := net.JoinHostPort(s.smtpHost, fmt.Sprintf("%d", s.smtpPort))
dialer := &net.Dialer{Timeout: DefaultTimeout}
conn, err := dialer.Dial("tcp", addr)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect to SMTP server: %w", err)
}
client, err := smtp.NewClient(conn, s.smtpHost)
if err != nil {
_ = conn.Close()
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)
}
if ok, _ := client.Extension("STARTTLS"); ok {
if err := client.StartTLS(&tls.Config{ServerName: s.smtpHost}); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("STARTTLS failed: %w", err)
}
}
return client, func() { _ = client.Quit() }, nil
}
func (s *EmailSMTPSender) authenticateWithRetry(
createClient func() (*smtp.Client, func(), error),
isAuthRequired bool,
) (*smtp.Client, func(), error) {
client, cleanup, err := createClient()
if err != nil {
return nil, nil, err
}
if !isAuthRequired {
return client, cleanup, nil
}
// Try PLAIN auth first
plainAuth := smtp.PlainAuth("", s.smtpUser, s.smtpPassword, s.smtpHost)
if err := client.Auth(plainAuth); err == nil {
return client, cleanup, nil
}
// PLAIN auth failed, connection may be closed - recreate and try LOGIN auth
cleanup()
client, cleanup, err = createClient()
if err != nil {
return nil, nil, err
}
loginAuth := &loginAuth{username: s.smtpUser, password: s.smtpPassword}
if err := client.Auth(loginAuth); err != nil {
cleanup()
return nil, nil, fmt.Errorf("SMTP authentication failed: %w", err)
}
return client, cleanup, nil
}
func (s *EmailSMTPSender) sendEmail(client *smtp.Client, to, from string, content []byte) error {
if err := client.Mail(from); err != nil {
return fmt.Errorf("failed to set sender: %w", err)
}
if err := client.Rcpt(to); err != nil {
return fmt.Errorf("failed to set recipient: %w", err)
}
writer, err := client.Data()
if err != nil {
return fmt.Errorf("failed to get data writer: %w", err)
}
if _, err = writer.Write(content); err != nil {
return fmt.Errorf("failed to write email content: %w", err)
}
if err = writer.Close(); err != nil {
return fmt.Errorf("failed to close data writer: %w", err)
}
return nil
}
func encodeRFC2047(s string) string {
return mime.QEncoding.Encode("UTF-8", s)
}
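mime.QEncoding leaves printable-ASCII input untouched and Q-encodes everything else, so plain subjects stay readable on the wire. Expected behavior per Go's mime package (outputs shown as comments; treat the exact byte sequences as illustrative):
fmt.Println(encodeRFC2047("Backup completed"))
// Backup completed                        (printable ASCII passes through unchanged)
fmt.Println(encodeRFC2047("Résumé ✓"))
// =?UTF-8?q?R=C3=A9sum=C3=A9_=E2=9C=93?=  (space becomes "_", non-ASCII bytes are hex-escaped)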
type loginAuth struct {
username string
password string
}
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
return "LOGIN", []byte{}, nil
}
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if more {
switch string(fromServer) {
case "Username:", "User Name\x00":
return []byte(a.username), nil
case "Password:", "Password\x00":
return []byte(a.password), nil
default:
return []byte(a.username), nil
}
}
return nil, nil
}
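loginAuth implements net/smtp's Auth interface for the legacy AUTH LOGIN mechanism, which some servers require when PLAIN is rejected. The exchange it drives looks roughly like this (net/smtp handles the base64 framing; Next only sees the decoded prompts):
// C: AUTH LOGIN
// S: 334 VXNlcm5hbWU6   (base64 "Username:")  -> Next returns a.username
// C: <base64 username>
// S: 334 UGFzc3dvcmQ6   (base64 "Password:")  -> Next returns a.password
// C: <base64 password>
// S: 235 Authentication successful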

View File

@@ -2,30 +2,47 @@ package healthcheck_attempt
import (
"context"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"fmt"
"log/slog"
"sync"
"sync/atomic"
"time"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
)
type HealthcheckAttemptBackgroundService struct {
healthcheckConfigService *healthcheck_config.HealthcheckConfigService
checkDatabaseHealthUseCase *CheckDatabaseHealthUseCase
logger *slog.Logger
runOnce sync.Once
hasRun atomic.Bool
}
func (s *HealthcheckAttemptBackgroundService) Run(ctx context.Context) {
// first healthcheck immediately
s.checkDatabases()
wasAlreadyRun := s.hasRun.Load()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
s.runOnce.Do(func() {
s.hasRun.Store(true)
// first healthcheck immediately
s.checkDatabases()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
s.checkDatabases()
}
}
})
if wasAlreadyRun {
panic(fmt.Sprintf("%T.Run() called multiple times", s))
}
}
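With this change a second Run call can no longer start a competing ticker loop: runOnce runs the body at most once, and the hasRun snapshot taken before Do turns any later repeat call into a panic. Intended call shape (illustrative):
// Run is started exactly once, in its own goroutine, and loops until ctx is cancelled.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go healthcheckAttemptBackgroundService.Run(ctx)
// A later second Run(ctx) panics with
// "*healthcheck_attempt.HealthcheckAttemptBackgroundService.Run() called multiple times".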

View File

@@ -144,6 +144,10 @@ func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "forbidden")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -181,6 +185,10 @@ func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
for _, attempt := range response {
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
@@ -201,6 +209,10 @@ func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
)
assert.Equal(t, 0, len(response))
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -1,6 +1,9 @@
package healthcheck_attempt
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/databases"
healthcheck_config "databasus-backend/internal/features/healthcheck/config"
"databasus-backend/internal/features/notifiers"
@@ -22,9 +25,11 @@ var checkDatabaseHealthUseCase = &CheckDatabaseHealthUseCase{
}
var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase,
logger.GetLogger(),
healthcheckConfigService: healthcheck_config.GetHealthcheckConfigService(),
checkDatabaseHealthUseCase: checkDatabaseHealthUseCase,
logger: logger.GetLogger(),
runOnce: sync.Once{},
hasRun: atomic.Bool{},
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,

View File

@@ -130,6 +130,10 @@ func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -162,6 +166,10 @@ func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
@@ -268,6 +276,10 @@ func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
@@ -295,6 +307,10 @@ func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T)
assert.Equal(t, 1, response.IntervalMinutes)
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
assert.Equal(t, 7, response.StoreAttemptsDays)
// Cleanup
databases.RemoveTestDatabase(database)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func createTestDatabaseViaAPI(

View File

@@ -1,6 +1,9 @@
package healthcheck_config
import (
"sync"
"sync/atomic"
"databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/databases"
workspaces_services "databasus-backend/internal/features/workspaces/services"
@@ -27,8 +30,23 @@ func GetHealthcheckConfigController() *HealthcheckConfigController {
return healthcheckConfigController
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
databases.
GetDatabaseService().
AddDbCreationListener(healthcheckConfigService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -1,6 +1,9 @@
package notifiers
import (
"sync"
"sync/atomic"
audit_logs "databasus-backend/internal/features/audit_logs"
workspaces_services "databasus-backend/internal/features/workspaces/services"
"databasus-backend/internal/util/encryption"
@@ -32,6 +35,22 @@ func GetNotifierService() *NotifierService {
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
var (
setupOnce sync.Once
isSetup atomic.Bool
)
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
wasAlreadySetup := isSetup.Load()
setupOnce.Do(func() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
isSetup.Store(true)
})
if wasAlreadySetup {
logger.GetLogger().Warn("SetupDependencies called multiple times, ignoring subsequent call")
}
}

View File

@@ -6,19 +6,20 @@ import (
"errors"
"fmt"
"log/slog"
"mime"
"net"
"net/smtp"
"os"
"time"
"github.com/google/uuid"
)
const (
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
DefaultHelloName = "localhost"
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
ImplicitTLSPort = 465
DefaultTimeout = 5 * time.Second
MIMETypeHTML = "text/html"
MIMECharsetUTF8 = "UTF-8"
)
type EmailNotifier struct {
@@ -115,16 +116,46 @@ func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor
return nil
}
func getHelloName() string {
hostname, err := os.Hostname()
if err != nil || hostname == "" {
return "localhost"
}
return hostname
}
// encodeRFC2047 encodes a string using RFC 2047 MIME encoding for email headers
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
func encodeRFC2047(s string) string {
// mime.QEncoding handles UTF-8 → =?UTF-8?Q?...?= encoding
// This allows non-ASCII characters (emojis, accents, etc.) in email headers
// while maintaining compatibility with all SMTP servers
return mime.QEncoding.Encode("UTF-8", s)
}
func (e *EmailNotifier) buildEmailContent(heading, message, from string) []byte {
subject := fmt.Sprintf("Subject: %s\r\n", heading)
mime := fmt.Sprintf(
// Encode Subject header using RFC 2047 to avoid SMTPUTF8 requirement
// This ensures compatibility with SMTP servers that don't support SMTPUTF8
encodedSubject := encodeRFC2047(heading)
subject := fmt.Sprintf("Subject: %s\r\n", encodedSubject)
dateHeader := fmt.Sprintf("Date: %s\r\n", time.Now().UTC().Format(time.RFC1123Z))
messageID := fmt.Sprintf("Message-ID: <%s@%s>\r\n", uuid.New().String(), e.SMTPHost)
mimeHeaders := fmt.Sprintf(
"MIME-version: 1.0;\nContent-Type: %s; charset=\"%s\";\n\n",
MIMETypeHTML,
MIMECharsetUTF8,
)
fromHeader := fmt.Sprintf("From: %s\r\n", from)
// Encode From header display name if it contains non-ASCII
encodedFrom := encodeRFC2047(from)
fromHeader := fmt.Sprintf("From: %s\r\n", encodedFrom)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
return []byte(fromHeader + toHeader + subject + mime + message)
return []byte(fromHeader + toHeader + subject + dateHeader + messageID + mimeHeaders + message)
}
func (e *EmailNotifier) sendImplicitTLS(
@@ -199,7 +230,7 @@ func (e *EmailNotifier) createStartTLSClient() (*smtp.Client, func(), error) {
return nil, nil, fmt.Errorf("failed to create SMTP client: %w", err)
}
if err := client.Hello(DefaultHelloName); err != nil {
if err := client.Hello(getHelloName()); err != nil {
_ = client.Quit()
_ = conn.Close()
return nil, nil, fmt.Errorf("SMTP hello failed: %w", err)

View File

@@ -58,6 +58,8 @@ func (s *NotifierService) SaveNotifier(
return err
}
oldName := existingNotifier.Name
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
@@ -67,11 +69,23 @@ func (s *NotifierService) SaveNotifier(
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
if oldName != existingNotifier.Name {
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Notifier updated and renamed from '%s' to '%s'",
oldName,
existingNotifier.Name,
),
&user.ID,
&workspaceID,
)
} else {
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
}
} else {
notifier.WorkspaceID = workspaceID
@@ -343,9 +357,19 @@ func (s *NotifierService) TransferNotifierToWorkspace(
return err
}
sourceWorkspace, err := s.workspaceService.GetWorkspaceByID(sourceWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get source workspace: %w", err)
}
targetWorkspace, err := s.workspaceService.GetWorkspaceByID(targetWorkspaceID)
if err != nil {
return fmt.Errorf("failed to get target workspace: %w", err)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier transferred: %s from workspace %s to workspace %s",
existingNotifier.Name, sourceWorkspaceID, targetWorkspaceID),
fmt.Sprintf("Notifier transferred: %s from workspace '%s' to workspace '%s'",
existingNotifier.Name, sourceWorkspace.Name, targetWorkspace.Name),
&user.ID,
&targetWorkspaceID,
)

View File

@@ -0,0 +1,20 @@
package plans
import (
"databasus-backend/internal/util/logger"
)
var databasePlanRepository = &DatabasePlanRepository{}
var databasePlanService = &DatabasePlanService{
databasePlanRepository,
logger.GetLogger(),
}
func GetDatabasePlanService() *DatabasePlanService {
return databasePlanService
}
func GetDatabasePlanRepository() *DatabasePlanRepository {
return databasePlanRepository
}

View File

@@ -0,0 +1,19 @@
package plans
import (
"databasus-backend/internal/util/period"
"github.com/google/uuid"
)
type DatabasePlan struct {
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;primaryKey;not null"`
MaxBackupSizeMB int64 `json:"maxBackupSizeMb" gorm:"column:max_backup_size_mb;type:int;not null"`
MaxBackupsTotalSizeMB int64 `json:"maxBackupsTotalSizeMb" gorm:"column:max_backups_total_size_mb;type:int;not null"`
MaxStoragePeriod period.TimePeriod `json:"maxStoragePeriod" gorm:"column:max_storage_period;type:text;not null"`
}
func (p *DatabasePlan) TableName() string {
return "database_plans"
}

View File

@@ -0,0 +1,27 @@
package plans
import (
"databasus-backend/internal/storage"
"github.com/google/uuid"
)
type DatabasePlanRepository struct{}
func (r *DatabasePlanRepository) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
var databasePlan DatabasePlan
if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
if err.Error() == "record not found" {
return nil, nil
}
return nil, err
}
return &databasePlan, nil
}
func (r *DatabasePlanRepository) CreateDatabasePlan(databasePlan *DatabasePlan) error {
return storage.GetDb().Create(&databasePlan).Error
}
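GetDatabasePlan maps a missing row to (nil, nil) by comparing the error text. With GORM, which the storage layer here appears to use, the less brittle equivalent is errors.Is against gorm.ErrRecordNotFound (whose message is exactly "record not found"). A sketch under that assumption:
// Assumes GORM ("errors" and "gorm.io/gorm" imported); behavior is identical.
if err := storage.GetDb().Where("database_id = ?", databaseID).First(&databasePlan).Error; err != nil {
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, nil // no plan yet; the service layer creates a default
	}
	return nil, err
}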

View File

@@ -0,0 +1,67 @@
package plans
import (
"databasus-backend/internal/config"
"databasus-backend/internal/util/period"
"log/slog"
"github.com/google/uuid"
)
type DatabasePlanService struct {
databasePlanRepository *DatabasePlanRepository
logger *slog.Logger
}
func (s *DatabasePlanService) GetDatabasePlan(databaseID uuid.UUID) (*DatabasePlan, error) {
plan, err := s.databasePlanRepository.GetDatabasePlan(databaseID)
if err != nil {
return nil, err
}
if plan == nil {
s.logger.Info("no database plan found, creating default plan", "databaseID", databaseID)
defaultPlan := s.createDefaultDatabasePlan(databaseID)
err := s.databasePlanRepository.CreateDatabasePlan(defaultPlan)
if err != nil {
s.logger.Error("failed to create default database plan", "error", err)
return nil, err
}
return defaultPlan, nil
}
return plan, nil
}
func (s *DatabasePlanService) createDefaultDatabasePlan(databaseID uuid.UUID) *DatabasePlan {
var plan DatabasePlan
isCloud := config.GetEnv().IsCloud
if isCloud {
s.logger.Info("creating default database plan for cloud", "databaseID", databaseID)
// for the playground we set limits that are enough for testing,
// but not too expensive for Databasus to provide
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 100, // ~ 1.5GB database
MaxBackupsTotalSizeMB: 4000, // ~ 30 daily backups + 10 manual backups
MaxStoragePeriod: period.PeriodWeek,
}
} else {
s.logger.Info("creating default database plan for self hosted", "databaseID", databaseID)
// by default - everything is unlimited in self hosted mode
plan = DatabasePlan{
DatabaseID: databaseID,
MaxBackupSizeMB: 0,
MaxBackupsTotalSizeMB: 0,
MaxStoragePeriod: period.PeriodForever,
}
}
return &plan
}
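In the self-hosted defaults a zero size limit means unlimited, so any enforcement code has to special-case zero before comparing against a cap. The enforcement itself is not part of this diff; a hypothetical check for illustration:
// Hypothetical sketch: zero MaxBackupSizeMB means "no limit" (self-hosted default).
func exceedsBackupSizeLimit(plan *DatabasePlan, backupSizeMB int64) bool {
	if plan.MaxBackupSizeMB == 0 {
		return false
	}
	return backupSizeMB > plan.MaxBackupSizeMB
}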

View File

@@ -1,38 +0,0 @@
package restores
import (
"context"
"databasus-backend/internal/features/restores/enums"
"log/slog"
)
type RestoreBackgroundService struct {
restoreRepository *RestoreRepository
logger *slog.Logger
}
func (s *RestoreBackgroundService) Run(ctx context.Context) {
if err := s.failRestoresInProgress(); err != nil {
s.logger.Error("Failed to fail restores in progress", "error", err)
panic(err)
}
}
func (s *RestoreBackgroundService) failRestoresInProgress() error {
restoresInProgress, err := s.restoreRepository.FindByStatus(enums.RestoreStatusInProgress)
if err != nil {
return err
}
for _, restore := range restoresInProgress {
failMessage := "Restore failed due to application restart"
restore.Status = enums.RestoreStatusFailed
restore.FailMessage = &failMessage
if err := s.restoreRepository.Save(restore); err != nil {
return err
}
}
return nil
}

View File

@@ -1,6 +1,7 @@
package restores
import (
restores_core "databasus-backend/internal/features/restores/core"
users_middleware "databasus-backend/internal/features/users/middleware"
"net/http"
@@ -15,6 +16,7 @@ type RestoreController struct {
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/restores/:backupId", c.GetRestores)
router.POST("/restores/:backupId/restore", c.RestoreBackup)
router.POST("/restores/cancel/:restoreId", c.CancelRestore)
}
// GetRestores
@@ -23,7 +25,7 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Tags restores
// @Produce json
// @Param backupId path string true "Backup ID"
// @Success 200 {array} models.Restore
// @Success 200 {array} restores_core.Restore
// @Failure 400
// @Failure 401
// @Router /restores/{backupId} [get]
@@ -71,7 +73,7 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
return
}
var requestDTO RestoreBackupRequest
var requestDTO restores_core.RestoreBackupRequest
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -84,3 +86,33 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
ctx.JSON(http.StatusOK, gin.H{"message": "restore started successfully"})
}
// CancelRestore
// @Summary Cancel an in-progress restore
// @Description Cancel a restore that is currently in progress
// @Tags restores
// @Param restoreId path string true "Restore ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Router /restores/cancel/{restoreId} [post]
func (c *RestoreController) CancelRestore(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
restoreID, err := uuid.Parse(ctx.Param("restoreId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid restore ID"})
return
}
if err := c.restoreService.CancelRestore(user, restoreID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}

View File

@@ -16,7 +16,7 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"databasus-backend/internal/config"
env_config "databasus-backend/internal/config"
audit_logs "databasus-backend/internal/features/audit_logs"
"databasus-backend/internal/features/backups/backups"
backups_core "databasus-backend/internal/features/backups/backups/core"
@@ -24,16 +24,18 @@ import (
"databasus-backend/internal/features/databases"
"databasus-backend/internal/features/databases/databases/mysql"
"databasus-backend/internal/features/databases/databases/postgresql"
"databasus-backend/internal/features/restores/models"
"databasus-backend/internal/features/notifiers"
restores_core "databasus-backend/internal/features/restores/core"
"databasus-backend/internal/features/restores/restoring"
"databasus-backend/internal/features/storages"
local_storage "databasus-backend/internal/features/storages/models/local"
tasks_cancellation "databasus-backend/internal/features/tasks/cancellation"
users_dto "databasus-backend/internal/features/users/dto"
users_enums "databasus-backend/internal/features/users/enums"
users_services "databasus-backend/internal/features/users/services"
users_testing "databasus-backend/internal/features/users/testing"
workspaces_controllers "databasus-backend/internal/features/workspaces/controllers"
workspaces_models "databasus-backend/internal/features/workspaces/models"
workspaces_testing "databasus-backend/internal/features/workspaces/testing"
cache_utils "databasus-backend/internal/util/cache"
util_encryption "databasus-backend/internal/util/encryption"
test_utils "databasus-backend/internal/util/testing"
"databasus-backend/internal/util/tools"
@@ -43,10 +45,12 @@ func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -65,8 +69,10 @@ func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
@@ -85,12 +91,14 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
var restores []*models.Restore
var restores []*restores_core.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
@@ -105,15 +113,21 @@ func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -136,15 +150,17 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -165,15 +181,21 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -195,15 +217,21 @@ func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := RestoreBackupRequest{
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -233,7 +261,7 @@ func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Database restored from backup") &&
if strings.Contains(log.Message, "Database restored for database") &&
strings.Contains(log.Message, database.Name) {
found = true
break
@@ -272,18 +300,29 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
router := createTestRouter()
// Setup mock node for tests that skip disk validation and reach scheduler
if !tc.expectDiskValidated {
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
}
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
var database *databases.Database
var backup *backups_core.Backup
var request RestoreBackupRequest
var storage *storages.Storage
var request restores_core.RestoreBackupRequest
if tc.dbType == databases.DatabaseTypePostgres {
_, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
request = RestoreBackupRequest{
database, backup = createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request = restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
@@ -297,7 +336,16 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
owner.Token,
router,
)
storage := createTestStorage(workspace.ID)
database = mysqlDB
storage = createTestStorage(workspace.ID)
defer func() {
// Cleanup in dependency order: backup -> database -> storage
cleanupBackup(backup)
databases.RemoveTestDatabase(mysqlDB)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
}()
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(mysqlDB.ID)
@@ -309,11 +357,12 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup = createTestBackup(mysqlDB, owner)
request = RestoreBackupRequest{
backup = createTestBackup(mysqlDB, storage)
request = restores_core.RestoreBackupRequest{
MysqlDatabase: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: 3306,
Username: "root",
Password: "password",
@@ -353,16 +402,189 @@ func Test_RestoreBackup_DiskSpaceValidation(t *testing.T) {
}
}
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
func Test_CancelRestore_InProgressRestore_SuccessfullyCancelled(t *testing.T) {
cache_utils.ClearAllCache()
tasks_cancellation.SetupDependencies()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := createTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer func() {
backupRepo := backups_core.BackupRepository{}
backups, _ := backupRepo.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepo.DeleteByID(backup.ID)
}
restoreRepo := restores_core.RestoreRepository{}
restores, _ := restoreRepo.FindByStatus(restores_core.RestoreStatusInProgress)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
restores, _ = restoreRepo.FindByStatus(restores_core.RestoreStatusCanceled)
for _, restore := range restores {
restoreRepo.DeleteByID(restore.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspace(workspace, router)
cache_utils.ClearAllCache()
}()
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
backup := backups.CreateTestBackup(database.ID, storage.ID)
mockUsecase := &restoring.MockBlockingRestoreUsecase{
StartedChan: make(chan bool, 1),
}
restorerNode := restoring.CreateTestRestorerNodeWithUsecase(mockUsecase)
cancelNode := restoring.StartRestorerNodeForTest(t, restorerNode)
defer cancelNode()
time.Sleep(200 * time.Millisecond)
restoreRequest := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
var restoreResponse map[string]interface{}
test_utils.MakePostRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+user.Token,
restoreRequest,
http.StatusOK,
&restoreResponse,
)
return router
select {
case <-mockUsecase.StartedChan:
t.Log("Restore started and is blocking")
case <-time.After(2 * time.Second):
t.Fatal("Restore did not start within timeout")
}
restoreRepo := &restores_core.RestoreRepository{}
restores, err := restoreRepo.FindByBackupID(backup.ID)
assert.NoError(t, err)
assert.Greater(t, len(restores), 0, "At least one restore should exist")
var restoreID uuid.UUID
for _, r := range restores {
if r.Status == restores_core.RestoreStatusInProgress {
restoreID = r.ID
break
}
}
assert.NotEqual(t, uuid.Nil, restoreID, "Should find an in-progress restore")
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/cancel/%s", restoreID.String()),
"Bearer "+user.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
deadline := time.Now().UTC().Add(3 * time.Second)
var restore *restores_core.Restore
for time.Now().UTC().Before(deadline) {
restore, err = restoreRepo.FindByID(restoreID)
assert.NoError(t, err)
if restore.Status == restores_core.RestoreStatusCanceled {
break
}
time.Sleep(100 * time.Millisecond)
}
assert.Equal(t, restores_core.RestoreStatusCanceled, restore.Status)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Restore cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
time.Sleep(200 * time.Millisecond)
}
func Test_RestoreBackup_WithParallelRestoreInProgress_ReturnsError(t *testing.T) {
router := createTestRouter()
_, cleanup := SetupMockRestoreNode(t)
defer cleanup()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
defer workspaces_testing.RemoveTestWorkspace(workspace, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
defer cleanupDatabaseWithBackup(database, backup)
request := restores_core.RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: env_config.GetEnv().TestLocalhost,
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
testResp2 := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp2.Body), "another restore is already in progress")
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
func createTestDatabaseWithBackupForRestore(
@@ -387,7 +609,7 @@ func createTestDatabaseWithBackupForRestore(
panic(err)
}
backup := createTestBackup(database, owner)
backup := createTestBackup(database, storage)
return database, backup
}
@@ -433,7 +655,7 @@ func createTestMySQLDatabase(
token string,
router *gin.Engine,
) *databases.Database {
env := config.GetEnv()
env := env_config.GetEnv()
portStr := env.TestMysql80Port
if portStr == "" {
portStr = "33080"
@@ -451,7 +673,7 @@ func createTestMySQLDatabase(
Type: databases.DatabaseTypeMysql,
Mysql: &mysql.MysqlDatabase{
Version: tools.MysqlVersion80,
Host: "localhost",
Host: env_config.GetEnv().TestLocalhost,
Port: port,
Username: "testuser",
Password: "testpassword",
@@ -504,24 +726,14 @@ func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
storage *storages.Storage,
) *backups_core.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &backups_core.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
StorageID: storage.ID,
Status: backups_core.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
@@ -536,11 +748,11 @@ func createTestBackup(
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(
if err := storage.SaveFile(
context.Background(),
fieldEncryptor,
logger,
backup.ID,
backup.ID.String(),
reader,
); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
@@ -548,3 +760,22 @@ func createTestBackup(
return backup
}
func cleanupDatabaseWithBackup(database *databases.Database, backup *backups_core.Backup) {
// Clean up in reverse dependency order
cleanupBackup(backup)
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond)
// Clean up storage last (after database and backup are removed)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err == nil && config.StorageID != nil {
storages.RemoveTestStorage(*config.StorageID)
}
}
func cleanupBackup(backup *backups_core.Backup) {
repo := &backups_core.BackupRepository{}
repo.DeleteByID(backup.ID)
}

Some files were not shown because too many files have changed in this diff.