Compare commits

...

32 Commits

Author SHA1 Message Date
Rostislav Dugin
244a56d1bb FEATURE (secrets): Move secrets to the secret.key file instead of DB 2025-11-19 18:53:58 +03:00
Rostislav Dugin
95c833b619 FIX (backups): Fix passing encypted password to .pgpass 2025-11-19 17:10:19 +03:00
Rostislav Dugin
878fad5747 FEATURE (encryption): Add encyption for secrets in notifiers and storages 2025-11-18 21:23:59 +03:00
Rostislav Dugin
6ff3096695 FIX (password reset): Allow to change user password even if password was not set before 2025-11-17 20:20:31 +03:00
Rostislav Dugin
b4b514c2d5 FEATURE (encryption): Add backups encryption 2025-11-17 14:33:37 +03:00
Rostislav Dugin
da0fec6624 FEATURE (azure): Add Azure Blob Storage 2025-11-16 23:38:20 +03:00
Rostislav Dugin
408675023a FEATURE (s3): Add support of virtual-styled-domains and S3 prefix 2025-11-16 11:22:03 +03:00
Rostislav Dugin
0bc93389cc FEATURE (backups): Include workspace name in notification about success or fail 2025-11-15 11:40:42 +03:00
Rostislav Dugin
c8e6aea6e1 FEATURE (hints): Add hints about localhost connection 2025-11-15 00:25:51 +03:00
Rostislav Dugin
981ad21471 FEATURE (email): Add "to" header to email 2025-11-14 20:39:02 +03:00
Rostislav Dugin
177a9c782c Revert "FIX (notifiers): Improve email validation"
This reverts commit 02c735bc5a.
2025-11-14 20:35:22 +03:00
Rostislav Dugin
069d6bc8fe FEATURE (logo): Update logo 2025-11-14 20:19:26 +03:00
Rostislav Dugin
242d5543d4 FIX (backups): Avoid possibility of breaking DB on backup fail 2025-11-14 19:56:56 +03:00
Rostislav Dugin
02c735bc5a FIX (notifiers): Improve email validation 2025-11-14 18:02:27 +03:00
Rostislav Dugin
793b575146 FIX (storages): Ignore files removal errors for unavailable storage when deleting the database 2025-11-14 18:02:13 +03:00
Rostislav Dugin
a6e84b45f2 Merge pull request #84 from RostislavDugin/feature/add_pg_12
Feature/add pg 12
2025-11-12 15:43:09 +03:00
Rostislav Dugin
a941fbd093 FEATURE (postgres): Add PostgreSQL 12 tests and CI \ CD config 2025-11-12 15:39:44 +03:00
Rostislav Dugin
4492ba41f5 Merge pull request #82 from romanesko/feature/v12-support
feat: add PostgreSQL 12 support
2025-11-12 15:04:12 +03:00
Roman Bykovsky
3a5ac4b479 feat: add PostgreSQL 12 support 2025-11-11 18:53:26 +03:00
Rostislav Dugin
77aaabeaa1 FEATURE (docs): Update readme and docs links 2025-11-11 16:56:33 +03:00
Rostislav Dugin
01911dbf72 FIX (notifiers & storages): Avoid request for workspace_id for storages and notifiers removal 2025-11-11 10:05:45 +03:00
Rostislav Dugin
1a16f27a5d FIX (notifiers): Fix update of existing DB notifiers 2025-11-11 08:10:02 +03:00
Rostislav Dugin
778db71625 FIX (tests): Improve tests stability in CI \ CD 2025-11-09 20:41:36 +03:00
Rostislav Dugin
45fc9a7fff FIX (databases): Verify DB nil on side of DB instead of interface 2025-11-09 20:03:22 +03:00
Rostislav Dugin
7f5e786261 FIX (databases): If some DB missing PostgreSQL db fix nil issue 2025-11-09 18:57:42 +03:00
Rostislav Dugin
9b066bcb8a FEATURE (email): Add "from" field 2025-11-08 20:47:35 +03:00
Rostislav Dugin
9ea795b48f FEATURE (backups): Add backups cancelling 2025-11-08 20:04:06 +03:00
Rostislav Dugin
a809dc8a9c FEATURE (protection): Do not expose sensetive data of databases, notifiers and storages from API + make backups lazy loaded 2025-11-08 18:49:23 +03:00
Rostislav Dugin
bd053b51a3 FIX (workspaces): Fix switch between workspaces 2025-11-07 15:54:15 +03:00
Rostislav Dugin
431e9861f4 FEATURE (workspaces): Add workspaces with users management and global Postgresus settings 2025-11-07 15:28:03 +03:00
Rostislav Dugin
de1fd4c4da Merge pull request #56 from SebasGDEV/fix/readmetypo
fix: typo redirecting to contribute/README.me
2025-11-02 11:53:03 +03:00
Sebastian G
df55fd17d5 fix: typo redirecting to contribute/README.me 2025-11-01 11:27:58 -07:00
331 changed files with 26862 additions and 3459 deletions

View File

@@ -127,6 +127,7 @@ jobs:
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
TEST_GOOGLE_DRIVE_TOKEN_JSON=${{ secrets.TEST_GOOGLE_DRIVE_TOKEN_JSON }}
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -136,8 +137,13 @@ jobs:
# testing S3
TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing Azure Blob
TEST_AZURITE_BLOB_PORT=10000
# testing NAS
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
EOF
- name: Start test containers
@@ -151,6 +157,7 @@ jobs:
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
@@ -160,6 +167,9 @@ jobs:
# Wait for MinIO
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
# Wait for Azurite
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
- name: Create data and temp directories
run: |
# Create directories that are used for backups and restore
@@ -182,7 +192,7 @@ jobs:
- name: Run Go tests
run: |
cd backend
go test ./internal/...
go test -p=1 -count=1 -failfast ./internal/...
- name: Stop test containers
if: always()

View File

@@ -77,7 +77,7 @@ ENV APP_VERSION=$APP_VERSION
# Set production mode for Docker containers
ENV ENV_MODE=production
# Install PostgreSQL server and client tools (versions 13-17)
# Install PostgreSQL server and client tools (versions 12-18)
RUN apt-get update && apt-get install -y --no-install-recommends \
wget ca-certificates gnupg lsb-release sudo gosu && \
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
@@ -85,7 +85,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
> /etc/apt/sources.list.d/pgdg.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
postgresql-17 postgresql-18 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
rm -rf /var/lib/apt/lists/*

View File

@@ -187,7 +187,7 @@
same "license" line as the copyright notice for easier
identification within third-party archives.
Copyright 2025 LogBull
Copyright 2025 Postgresus
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -9,7 +9,7 @@
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/RostislavDugin/postgresus)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-13%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-12%20%7C%2013%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/RostislavDugin/postgresus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/RostislavDugin/postgresus)
@@ -40,13 +40,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗄️ **Multiple Storage Destinations**
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
- **Local storage**: Keep backups on your VPS/server
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
- **Secure**: All data stays under your control
### 📱 **Smart Notifications**
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
- **Real-time updates**: Success and failure notifications
@@ -54,17 +54,31 @@
### 🐘 **PostgreSQL Support**
- **Multiple versions**: PostgreSQL 13, 14, 15, 16, 17 and 18
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
- **SSL support**: Secure connections available
- **Easy restoration**: One-click restore from any backup
### 🔒 **Backup Encryption** <a href="https://postgresus.com/encryption">(docs)</a>
- **AES-256-GCM encryption**: Enterprise-grade protection for backup files
- **Zero-trust storage**: Encrypted backups are useless so you can keep in shared storages like S3, Azure Blob Storage, etc.
- **Optionality**: Encrypted backups are optional and can be enabled or disabled if you wish
- **Download unencrypted**: You can still download unencrypted backups via the 'Download' button to use them in `pg_restore` or other tools.
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
- **Access management**: Control who can view or manage specific databases with role-based permissions
- **Audit logs**: Track all system activities and changes made by users
- **User roles**: Assign viewer, member, admin or owner roles within workspaces
### 🐳 **Self-Hosted & Secure**
- **Docker-based**: Easy deployment and management
- **Privacy-first**: All your data stays on your infrastructure
- **Open source**: Apache 2.0 licensed, inspect every line of code
### 📦 Installation
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
You have three ways to install Postgresus:
@@ -118,8 +132,6 @@ This single command will:
Create a `docker-compose.yml` file with the following configuration:
```yaml
version: "3"
services:
postgresus:
container_name: postgresus
@@ -149,14 +161,16 @@ docker compose up -d
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
### 🔑 Resetting Admin Password
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
If you need to reset the admin password, you can use the built-in password reset command:
If you need to reset the password, you can use the built-in password reset command:
```bash
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123"
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123" --email="admin"
```
Replace `admin` with the actual email address of the user whose password you want to reset.
---
## 📝 License
@@ -167,4 +181,4 @@ This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENS
## 🤝 Contributing
Contributions are welcome! Read [contributing guide](contribute/readme.md) for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
Contributions are welcome! Read <a href="https://postgresus.com/contributing">contributing guide</a> for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 791 KiB

After

Width:  |  Height:  |  Size: 913 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 34 KiB

View File

@@ -1,14 +1,16 @@
---
description:
globs:
description:
globs:
alwaysApply: true
---
1. When we write controller:
- we combine all routes to single controller
- names them as .WhatWeDo (not "handlers") concept
2. We use gin and *gin.Context for all routes.
Example:
2. We use gin and \*gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
@@ -17,24 +19,26 @@ func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
package audit_logs
import (
"net/http"
"net/http"
user_models "logbull/internal/features/users/models"
user_models "postgresus/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
auditLogService \*AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
@@ -52,29 +56,30 @@ func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
@@ -94,34 +99,35 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -1,16 +1,18 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
``````
```
package audit_logs
import (
"net/http"
user_models "logbull/internal/features/users/models"
user_models "postgresus/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -117,9 +119,11 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
ctx.JSON(http.StatusOK, response)
}
``````
```
------ backend/internal/features/audit_logs/controller_test.go ------
``````
```
package audit_logs
import (
@@ -128,12 +132,12 @@ import (
"testing"
"time"
user_enums "logbull/internal/features/users/enums"
users_middleware "logbull/internal/features/users/middleware"
users_services "logbull/internal/features/users/services"
users_testing "logbull/internal/features/users/testing"
"logbull/internal/storage"
test_utils "logbull/internal/util/testing"
user_enums "postgresus/internal/features/users/enums"
users_middleware "postgresus/internal/features/users/middleware"
users_services "postgresus/internal/features/users/services"
users_testing "postgresus/internal/features/users/testing"
"postgresus/internal/storage"
test_utils "postgresus/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -256,14 +260,16 @@ func createRouter() *gin.Engine {
return router
}
``````
```
------ backend/internal/features/audit_logs/di.go ------
``````
```
package audit_logs
import (
users_services "logbull/internal/features/users/services"
"logbull/internal/util/logger"
users_services "postgresus/internal/features/users/services"
"postgresus/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
@@ -289,9 +295,11 @@ func SetupDependencies() {
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
``````
```
------ backend/internal/features/audit_logs/dto.go ------
``````
```
package audit_logs
import "time"
@@ -309,9 +317,11 @@ type GetAuditLogsResponse struct {
Offset int `json:"offset"`
}
``````
```
------ backend/internal/features/audit_logs/models.go ------
``````
```
package audit_logs
import (
@@ -332,13 +342,15 @@ func (AuditLog) TableName() string {
return "audit_logs"
}
``````
```
------ backend/internal/features/audit_logs/repository.go ------
``````
```
package audit_logs
import (
"logbull/internal/storage"
"postgresus/internal/storage"
"time"
"github.com/google/uuid"
@@ -429,9 +441,11 @@ func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
return count, err
}
``````
```
------ backend/internal/features/audit_logs/service.go ------
``````
```
package audit_logs
import (
@@ -439,8 +453,8 @@ import (
"log/slog"
"time"
user_enums "logbull/internal/features/users/enums"
user_models "logbull/internal/features/users/models"
user_enums "postgresus/internal/features/users/enums"
user_models "postgresus/internal/features/users/models"
"github.com/google/uuid"
)
@@ -560,17 +574,19 @@ func (s *AuditLogService) GetProjectAuditLogs(
}, nil
}
``````
```
------ backend/internal/features/audit_logs/service_test.go ------
``````
```
package audit_logs
import (
"testing"
"time"
user_enums "logbull/internal/features/users/enums"
users_testing "logbull/internal/features/users/testing"
user_enums "postgresus/internal/features/users/enums"
users_testing "postgresus/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -652,4 +668,4 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti
db.Create(log)
}
``````
```

View File

@@ -17,6 +17,7 @@ TEST_GOOGLE_DRIVE_CLIENT_ID=
TEST_GOOGLE_DRIVE_CLIENT_SECRET=
TEST_GOOGLE_DRIVE_TOKEN_JSON="{\"access_token\":\"ya29..."
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -27,4 +28,9 @@ TEST_POSTGRES_18_PORT=5006
TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing NAS
TEST_NAS_PORT=7006
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=
TEST_TELEGRAM_CHAT_ID=
# testing Azure Blob Storage
TEST_AZURITE_BLOB_PORT=10000

3
backend/.gitignore vendored
View File

@@ -12,4 +12,5 @@ swagger/swagger.yaml
postgresus-backend.exe
ui/build/*
pgdata-for-restore/
temp/
temp/
cmd.exe

View File

@@ -13,18 +13,22 @@ import (
"time"
"postgresus-backend/internal/config"
"postgresus-backend/internal/downdetect"
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/disk"
"postgresus-backend/internal/features/encryption/secrets"
healthcheck_attempt "postgresus-backend/internal/features/healthcheck/attempt"
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/restores"
"postgresus-backend/internal/features/storages"
system_healthcheck "postgresus-backend/internal/features/system/healthcheck"
"postgresus-backend/internal/features/users"
users_controllers "postgresus-backend/internal/features/users/controllers"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
env_utils "postgresus-backend/internal/util/env"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/logger"
@@ -61,13 +65,20 @@ func main() {
os.Exit(1)
}
// Handle password reset if flag is provided
newPassword := flag.String("new-password", "", "Set a new password for the user")
flag.Parse()
if *newPassword != "" {
resetPassword(*newPassword, log)
err = secrets.GetSecretKeyService().MigrateKeyFromDbToFileIfExist()
if err != nil {
log.Error("Failed to migrate secret key from database to file", "error", err)
os.Exit(1)
}
err = users_services.GetUserService().CreateInitialAdmin()
if err != nil {
log.Error("Failed to create initial admin", "error", err)
os.Exit(1)
}
handlePasswordReset(log)
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
@@ -91,11 +102,33 @@ func main() {
startServerWithGracefulShutdown(log, ginApp)
}
func resetPassword(newPassword string, log *slog.Logger) {
func handlePasswordReset(log *slog.Logger) {
audit_logs.SetupDependencies()
newPassword := flag.String("new-password", "", "Set a new password for the user")
email := flag.String("email", "", "Email of the user to reset password")
flag.Parse()
if *newPassword == "" {
return
}
log.Info("Found reset password command - reseting password...")
if *email == "" {
log.Info("No email provided, please provide an email via --email=\"some@email.com\" flag")
os.Exit(1)
}
resetPassword(*email, *newPassword, log)
}
func resetPassword(email string, newPassword string, log *slog.Logger) {
log.Info("Resetting password...")
userService := users.GetUserService()
err := userService.ChangePassword(newPassword)
userService := users_services.GetUserService()
err := userService.ChangeUserPasswordByEmail(email, newPassword)
if err != nil {
log.Error("Failed to reset password", "error", err)
os.Exit(1)
@@ -146,37 +179,44 @@ func setUpRoutes(r *gin.Engine) {
// Mount Swagger UI
v1.GET("/docs/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
downdetectContoller := downdetect.GetDowndetectController()
userController := users.GetUserController()
notifierController := notifiers.GetNotifierController()
storageController := storages.GetStorageController()
databaseController := databases.GetDatabaseController()
backupController := backups.GetBackupController()
restoreController := restores.GetRestoreController()
healthcheckController := system_healthcheck.GetHealthcheckController()
healthcheckConfigController := healthcheck_config.GetHealthcheckConfigController()
healthcheckAttemptController := healthcheck_attempt.GetHealthcheckAttemptController()
diskController := disk.GetDiskController()
backupConfigController := backups_config.GetBackupConfigController()
downdetectContoller.RegisterRoutes(v1)
// Public routes (only user auth routes and healthcheck should be public)
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
notifierController.RegisterRoutes(v1)
storageController.RegisterRoutes(v1)
databaseController.RegisterRoutes(v1)
backupController.RegisterRoutes(v1)
restoreController.RegisterRoutes(v1)
healthcheckController.RegisterRoutes(v1)
diskController.RegisterRoutes(v1)
healthcheckConfigController.RegisterRoutes(v1)
healthcheckAttemptController.RegisterRoutes(v1)
backupConfigController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
// Protected routes
protected := v1.Group("")
protected.Use(authMiddleware)
userController.RegisterProtectedRoutes(protected)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected)
workspaces_controllers.GetMembershipController().RegisterRoutes(protected)
disk.GetDiskController().RegisterRoutes(protected)
notifiers.GetNotifierController().RegisterRoutes(protected)
storages.GetStorageController().RegisterRoutes(protected)
databases.GetDatabaseController().RegisterRoutes(protected)
backups.GetBackupController().RegisterRoutes(protected)
restores.GetRestoreController().RegisterRoutes(protected)
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
backups_config.GetBackupConfigController().RegisterRoutes(protected)
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
}
func setUpDependencies() {
databases.SetupDependencies()
backups.SetupDependencies()
restores.SetupDependencies()
healthcheck_config.SetupDependencies()
audit_logs.SetupDependencies()
notifiers.SetupDependencies()
storages.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {

View File

@@ -31,7 +31,26 @@ services:
container_name: test-minio
command: server /data --console-address ":9001"
# Test Azurite container
test-azurite:
image: mcr.microsoft.com/azure-storage/azurite
ports:
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
container_name: test-azurite
command: azurite-blob --blobHost 0.0.0.0
# Test PostgreSQL containers
test-postgres-12:
image: postgres:12
ports:
- "${TEST_POSTGRES_12_PORT}:5432"
environment:
- POSTGRES_DB=testdb
- POSTGRES_USER=testuser
- POSTGRES_PASSWORD=testpassword
container_name: test-postgres-12
shm_size: 1gb
test-postgres-13:
image: postgres:13
ports:

View File

@@ -3,6 +3,8 @@ module postgresus-backend
go 1.23.3
require (
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
github.com/gin-contrib/cors v1.7.5
github.com/gin-contrib/gzip v1.2.3
github.com/gin-gonic/gin v1.10.0
@@ -15,16 +17,18 @@ require (
github.com/lib/pq v1.10.9
github.com/minio/minio-go/v7 v7.0.92
github.com/shirou/gopsutil/v4 v4.25.5
github.com/stretchr/testify v1.10.0
github.com/stretchr/testify v1.11.1
github.com/swaggo/files v1.0.1
github.com/swaggo/gin-swagger v1.6.0
github.com/swaggo/swag v1.16.4
golang.org/x/crypto v0.39.0
golang.org/x/crypto v0.41.0
golang.org/x/time v0.12.0
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.26.1
)
require github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
require (
cloud.google.com/go/auth v0.16.2 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
@@ -99,12 +103,12 @@ require (
go.opentelemetry.io/otel/metric v1.36.0 // indirect
go.opentelemetry.io/otel/trace v1.36.0 // indirect
golang.org/x/arch v0.17.0 // indirect
golang.org/x/net v0.41.0 // indirect
golang.org/x/net v0.43.0 // indirect
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.15.0 // indirect
golang.org/x/sys v0.33.0 // indirect
golang.org/x/text v0.26.0 // indirect
golang.org/x/tools v0.33.0 // indirect
golang.org/x/sync v0.16.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
golang.org/x/tools v0.35.0 // indirect
google.golang.org/api v0.239.0
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect

View File

@@ -6,6 +6,18 @@ cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeO
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 h1:KpMC6LFL7mqpExyMC9jVOYRiVhLmamjeZfRsUpB7l4s=
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig=
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4=
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8=
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI=
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
@@ -80,6 +92,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
@@ -131,6 +145,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@@ -159,6 +175,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY=
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
@@ -180,8 +198,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
@@ -216,25 +234,25 @@ golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -247,8 +265,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -257,15 +275,15 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo=
google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=

View File

@@ -26,13 +26,15 @@ type EnvVariables struct {
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
DataFolder string
TempFolder string
DataFolder string
TempFolder string
SecretKeyPath string
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
TestPostgres12Port string `env:"TEST_POSTGRES_12_PORT"`
TestPostgres13Port string `env:"TEST_POSTGRES_13_PORT"`
TestPostgres14Port string `env:"TEST_POSTGRES_14_PORT"`
TestPostgres15Port string `env:"TEST_POSTGRES_15_PORT"`
@@ -43,7 +45,19 @@ type EnvVariables struct {
TestMinioPort string `env:"TEST_MINIO_PORT"`
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
TestNASPort string `env:"TEST_NAS_PORT"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
}
var (
@@ -133,8 +147,13 @@ func loadEnvVariables() {
// (projectRoot/postgresus-data -> /postgresus-data)
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "backups")
env.TempFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "temp")
env.SecretKeyPath = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "secret.key")
if env.IsTesting {
if env.TestPostgres12Port == "" {
log.Error("TEST_POSTGRES_12_PORT is empty")
os.Exit(1)
}
if env.TestPostgres13Port == "" {
log.Error("TEST_POSTGRES_13_PORT is empty")
os.Exit(1)
@@ -169,10 +188,25 @@ func loadEnvVariables() {
os.Exit(1)
}
if env.TestAzuriteBlobPort == "" {
log.Error("TEST_AZURITE_BLOB_PORT is empty")
os.Exit(1)
}
if env.TestNASPort == "" {
log.Error("TEST_NAS_PORT is empty")
os.Exit(1)
}
if env.TestTelegramBotToken == "" {
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
os.Exit(1)
}
if env.TestTelegramChatID == "" {
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
os.Exit(1)
}
}
log.Info("Environment variables loaded successfully!")

View File

@@ -1,37 +0,0 @@
package downdetect
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
)
type DowndetectController struct {
service *DowndetectService
}
func (c *DowndetectController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/downdetect/is-available", c.IsAvailable)
}
// @Summary Check API availability
// @Description Checks if the API service is available
// @Tags downdetect
// @Accept json
// @Produce json
// @Success 200
// @Failure 500
// @Router /downdetect/api [get]
func (c *DowndetectController) IsAvailable(ctx *gin.Context) {
err := c.service.IsDbAvailable()
if err != nil {
ctx.JSON(
http.StatusInternalServerError,
gin.H{"error": fmt.Sprintf("Database is not available: %v", err)},
)
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "API and DB are available"})
}

View File

@@ -1,10 +0,0 @@
package downdetect
var downdetectService = &DowndetectService{}
var downdetectController = &DowndetectController{
downdetectService,
}
func GetDowndetectController() *DowndetectController {
return downdetectController
}

View File

@@ -1,17 +0,0 @@
package downdetect
import (
"postgresus-backend/internal/storage"
)
type DowndetectService struct {
}
func (s *DowndetectService) IsDbAvailable() error {
err := storage.GetDb().Exec("SELECT 1").Error
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,111 @@
package audit_logs
import (
"net/http"
user_models "postgresus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -0,0 +1,154 @@
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
workspaceID := uuid.New()
testID := uuid.New().String()
// Create test logs with unique identifiers
userLogMessage := fmt.Sprintf("Test log with user %s", testID)
workspaceLogMessage := fmt.Sprintf("Test log with workspace %s", testID)
standaloneLogMessage := fmt.Sprintf("Test log standalone %s", testID)
createAuditLog(service, userLogMessage, &adminUser.UserID, nil)
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
createAuditLog(service, standaloneLogMessage, nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=100", "Bearer "+adminUser.Token, http.StatusOK, &response)
// Verify our specific test logs are present
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, userLogMessage)
assert.Contains(t, messages, workspaceLogMessage)
assert.Contains(t, messages, standaloneLogMessage)
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
workspaceID := uuid.New()
testID := uuid.New().String()
// Create test logs for different users with unique identifiers
user1FirstMessage := fmt.Sprintf("Test log user1 first %s", testID)
user1SecondMessage := fmt.Sprintf("Test log user1 second %s", testID)
user2FirstMessage := fmt.Sprintf("Test log user2 first %s", testID)
user2SecondMessage := fmt.Sprintf("Test log user2 second %s", testID)
workspaceLogMessage := fmt.Sprintf("Test workspace log %s", testID)
createAuditLog(service, user1FirstMessage, &user1.UserID, nil)
createAuditLog(service, user1SecondMessage, &user1.UserID, &workspaceID)
createAuditLog(service, user2FirstMessage, &user2.UserID, nil)
createAuditLog(service, user2SecondMessage, &user2.UserID, &workspaceID)
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
// Verify user1's specific logs are present
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, user1FirstMessage)
assert.Contains(t, messages, user1SecondMessage)
// Count only our test logs for user1
testLogsCount := 0
for _, message := range messages {
if message == user1FirstMessage || message == user1SecondMessage {
testLogsCount++
}
}
assert.Equal(t, 2, testLogsCount)
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
// Verify user2's specific logs are present
ownMessages := extractMessages(ownLogsResponse.AuditLogs)
assert.Contains(t, ownMessages, user2FirstMessage)
assert.Contains(t, ownMessages, user2SecondMessage)
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_GetGlobalAuditLogs_WithBeforeDateFilter_ReturnsFilteredLogs(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
baseTime := time.Now().UTC()
// Set filter time to 30 minutes ago
beforeTime := baseTime.Add(-30 * time.Minute)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/audit-logs/global?beforeDate=%s&limit=1000",
beforeTime.Format(time.RFC3339),
),
"Bearer "+adminUser.Token,
http.StatusOK,
&filteredResponse,
)
// Verify ALL returned logs are older than the filter time
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime),
fmt.Sprintf("Log created at %s should be before filter time %s",
log.CreatedAt.Format(time.RFC3339), beforeTime.Format(time.RFC3339)))
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}

View File

@@ -0,0 +1,29 @@
package audit_logs
import (
users_services "postgresus-backend/internal/features/users/services"
"postgresus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}

View File

@@ -0,0 +1,31 @@
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLogDTO `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type AuditLogDTO struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
UserEmail *string `json:"userEmail" gorm:"column:user_email"`
UserName *string `json:"userName" gorm:"column:user_name"`
WorkspaceName *string `json:"workspaceName" gorm:"column:workspace_name"`
}

View File

@@ -0,0 +1,19 @@
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}

View File

@@ -0,0 +1,139 @@
package audit_logs
import (
"postgresus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id`
args := []interface{}{}
if beforeDate != nil {
sql += " WHERE al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.user_id = ?`
args := []interface{}{userID}
if beforeDate != nil {
sql += " AND al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByWorkspace(
workspaceID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.workspace_id = ?`
args := []interface{}{workspaceID}
if beforeDate != nil {
sql += " AND al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,137 @@
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
user_models "postgresus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
workspaceID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
WorkspaceID: workspaceID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(
targetUserID,
limit,
offset,
request.BeforeDate,
)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetWorkspaceAuditLogs(
workspaceID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByWorkspace(
workspaceID,
limit,
offset,
request.BeforeDate,
)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}

View File

@@ -0,0 +1,83 @@
package audit_logs
import (
"testing"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
workspace1ID, workspace2ID := uuid.New(), uuid.New()
// Create test logs for workspaces
createAuditLog(service, "Test workspace1 log first", &user1.UserID, &workspace1ID)
createAuditLog(service, "Test workspace1 log second", &user2.UserID, &workspace1ID)
createAuditLog(service, "Test workspace2 log first", &user1.UserID, &workspace2ID)
createAuditLog(service, "Test workspace2 log second", &user2.UserID, &workspace2ID)
createAuditLog(service, "Test no workspace log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test workspace 1 logs
workspace1Response, err := service.GetWorkspaceAuditLogs(workspace1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(workspace1Response.AuditLogs))
messages := extractMessages(workspace1Response.AuditLogs)
assert.Contains(t, messages, "Test workspace1 log first")
assert.Contains(t, messages, "Test workspace1 log second")
for _, log := range workspace1Response.AuditLogs {
assert.Equal(t, &workspace1ID, log.WorkspaceID)
}
// Test workspace 2 logs
workspace2Response, err := service.GetWorkspaceAuditLogs(workspace2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(workspace2Response.AuditLogs))
messages2 := extractMessages(workspace2Response.AuditLogs)
assert.Contains(t, messages2, "Test workspace2 log first")
assert.Contains(t, messages2, "Test workspace2 log second")
// Test pagination
limitedResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
assert.NotNil(t, log.UserEmail, "User email should be present for logs with user_id")
assert.NotNil(
t,
log.WorkspaceName,
"Workspace name should be present for logs with workspace_id",
)
}
}
func createAuditLog(service *AuditLogService, message string, userID, workspaceID *uuid.UUID) {
service.WriteAuditLog(message, userID, workspaceID)
}
func extractMessages(logs []*AuditLogDTO) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}

View File

@@ -5,6 +5,7 @@ import (
"postgresus-backend/internal/config"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/period"
"time"
)
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
continue
}
err = storage.DeleteFile(backup.ID)
encryptor := encryption.GetFieldEncryptor()
err = storage.DeleteFile(encryptor, backup.ID)
if err != nil {
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
}

View File

@@ -6,7 +6,9 @@ import (
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/period"
"testing"
"time"
@@ -16,10 +18,12 @@ import (
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -40,11 +44,8 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// add old backup
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -67,16 +68,20 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -97,11 +102,8 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -124,16 +126,20 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -157,11 +163,8 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -185,16 +188,20 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -218,11 +225,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -246,16 +250,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -280,11 +288,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -309,6 +314,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -0,0 +1,47 @@
package backups
import (
"context"
"errors"
"sync"
"github.com/google/uuid"
)
type BackupContextManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if !exists {
return errors.New("backup is not in progress or already completed")
}
cancelFunc()
delete(m.cancelFuncs, backupID)
return nil
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
}

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"io"
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -12,7 +12,6 @@ import (
type BackupController struct {
backupService *BackupService
userService *users.UserService
}
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
@@ -20,51 +19,48 @@ func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backups", c.MakeBackup)
router.GET("/backups/:id/file", c.GetFile)
router.DELETE("/backups/:id", c.DeleteBackup)
router.POST("/backups/:id/cancel", c.CancelBackup)
}
// GetBackups
// @Summary Get backups for a database
// @Description Get all backups for the specified database
// @Description Get paginated backups for the specified database
// @Tags backups
// @Produce json
// @Param database_id query string true "Database ID"
// @Success 200 {array} Backup
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetBackupsResponse
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups [get]
func (c *BackupController) GetBackups(ctx *gin.Context) {
databaseIDStr := ctx.Query("database_id")
if databaseIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "database_id query parameter is required"})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
databaseID, err := uuid.Parse(databaseIDStr)
var request GetBackupsRequest
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
databaseID, err := uuid.Parse(request.DatabaseID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
backups, err := c.backupService.GetBackups(user, databaseID)
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, backups)
ctx.JSON(http.StatusOK, response)
}
// MakeBackup
@@ -80,24 +76,18 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
// @Failure 500
// @Router /backups [post]
func (c *BackupController) MakeBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request MakeBackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -117,24 +107,18 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
// @Failure 500
// @Router /backups/{id} [delete]
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.backupService.DeleteBackup(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -143,6 +127,37 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
ctx.Status(http.StatusNoContent)
}
// CancelBackup
// @Summary Cancel an in-progress backup
// @Description Cancel a backup that is currently in progress
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/cancel [post]
func (c *BackupController) CancelBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
if err := c.backupService.CancelBackup(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup
@@ -154,24 +169,18 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
fileReader, err := c.backupService.GetBackupFile(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -179,19 +188,16 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
}
defer func() {
if err := fileReader.Close(); err != nil {
// Log the error but don't interrupt the response
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
// Set headers for file download
ctx.Header("Content-Type", "application/octet-stream")
ctx.Header(
"Content-Disposition",
fmt.Sprintf("attachment; filename=\"backup_%s.dump\"", id.String()),
)
// Stream the file content
_, err = io.Copy(ctx.Writer, fileReader)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})

View File

@@ -0,0 +1,709 @@
package backups
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
audit_logs "postgresus-backend/internal/features/audit_logs"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/storages"
local_storage "postgresus-backend/internal/features/storages/models/local"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func Test_GetBackups_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace viewer can get backups",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get backups",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get backups",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can get backups",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups?database_id=%s", database.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if tt.expectSuccess {
var response GetBackupsResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(response.Backups), 1)
assert.GreaterOrEqual(t, response.Total, int64(1))
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot create backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can create backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
request := MakeBackupRequest{DatabaseID: database.ID}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
)
if tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "backup started successfully")
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CreateBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
request := MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+owner.Token,
request,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup manually initiated") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup creation not found")
}
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace member can delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace viewer cannot delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "non-member cannot delete backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can delete backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
} else {
userService := users_services.GetUserService()
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
})
}
}
func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusNoContent,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup deleted") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup deletion not found")
}
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace viewer can download backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can download backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot download backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can download backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup file downloaded") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup download not found")
}
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/cancel", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
// Verify audit log was created
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
userService := users_services.GetUserService()
adminUser, err := userService.GetUserFromToken(admin.Token)
assert.NoError(t, err)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetGlobalAuditLogs(
adminUser,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
func createTestDatabase(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,
Type: storages.StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
}
repo := &storages.StorageRepository{}
storage, err := repo.Save(storage)
if err != nil {
panic(err)
}
return storage
}
func enableBackupForDatabase(databaseID uuid.UUID) {
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(databaseID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
}
func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
backup := createTestBackup(database, owner)
return database, backup
}
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
// Create a dummy backup file for testing download functionality
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}

View File

@@ -1,17 +1,24 @@
package backups
import (
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/usecases"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"time"
)
var backupRepository = &BackupRepository{}
var backupContextManager = NewBackupContextManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
@@ -19,9 +26,14 @@ var backupService = &BackupService{
notifiers.GetNotifierService(),
notifiers.GetNotifierService(),
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
}
var backupBackgroundService = &BackupBackgroundService{
@@ -35,7 +47,6 @@ var backupBackgroundService = &BackupBackgroundService{
var backupController = &BackupController{
backupService,
users.GetUserService(),
}
func SetupDependencies() {

View File

@@ -0,0 +1,28 @@
package backups
import (
"io"
"postgresus-backend/internal/features/backups/backups/encryption"
)
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type decryptionReaderCloser struct {
*encryption.DecryptionReader
baseReader io.ReadCloser
}
func (r *decryptionReaderCloser) Close() error {
return r.baseReader.Close()
}

View File

@@ -0,0 +1,156 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type DecryptionReader struct {
baseReader io.Reader
cipher cipher.AEAD
buffer []byte
nonce []byte
chunkIndex uint64
headerRead bool
eof bool
}
func NewDecryptionReader(
baseReader io.Reader,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*DecryptionReader, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
reader := &DecryptionReader{
baseReader,
aesgcm,
make([]byte, 0),
nonce,
0,
false,
false,
}
if err := reader.readAndValidateHeader(salt, nonce); err != nil {
return nil, err
}
return reader, nil
}
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
for len(r.buffer) < len(p) && !r.eof {
if err := r.readAndDecryptChunk(); err != nil {
if err == io.EOF {
r.eof = true
break
}
return 0, err
}
}
if len(r.buffer) == 0 {
return 0, io.EOF
}
n = copy(p, r.buffer)
r.buffer = r.buffer[n:]
return n, nil
}
func (r *DecryptionReader) readAndValidateHeader(expectedSalt, expectedNonce []byte) error {
header := make([]byte, HeaderLen)
if _, err := io.ReadFull(r.baseReader, header); err != nil {
return fmt.Errorf("failed to read header: %w", err)
}
magic := string(header[0:MagicBytesLen])
if magic != MagicBytes {
return fmt.Errorf("invalid magic bytes: expected %s, got %s", MagicBytes, magic)
}
salt := header[MagicBytesLen : MagicBytesLen+SaltLen]
nonce := header[MagicBytesLen+SaltLen : MagicBytesLen+SaltLen+NonceLen]
if string(salt) != string(expectedSalt) {
return fmt.Errorf("salt mismatch in file header")
}
if string(nonce) != string(expectedNonce) {
return fmt.Errorf("nonce mismatch in file header")
}
r.headerRead = true
return nil
}
func (r *DecryptionReader) readAndDecryptChunk() error {
lengthBuf := make([]byte, 4)
if _, err := io.ReadFull(r.baseReader, lengthBuf); err != nil {
return err
}
chunkLen := binary.BigEndian.Uint32(lengthBuf)
if chunkLen == 0 || chunkLen > ChunkSize+16 {
return fmt.Errorf("invalid chunk length: %d", chunkLen)
}
encrypted := make([]byte, chunkLen)
if _, err := io.ReadFull(r.baseReader, encrypted); err != nil {
return fmt.Errorf("failed to read encrypted chunk: %w", err)
}
chunkNonce := r.generateChunkNonce()
decrypted, err := r.cipher.Open(nil, chunkNonce, encrypted, nil)
if err != nil {
return fmt.Errorf(
"failed to decrypt chunk (authentication failed - file may be corrupted or tampered): %w",
err,
)
}
r.buffer = append(r.buffer, decrypted...)
r.chunkIndex++
return nil
}
func (r *DecryptionReader) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, r.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], r.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,147 @@
package encryption
import (
"crypto/aes"
"crypto/cipher"
"encoding/binary"
"fmt"
"io"
"github.com/google/uuid"
)
type EncryptionWriter struct {
baseWriter io.Writer
cipher cipher.AEAD
buffer []byte
nonce []byte
salt []byte
chunkIndex uint64
headerWritten bool
}
func NewEncryptionWriter(
baseWriter io.Writer,
masterKey string,
backupID uuid.UUID,
salt []byte,
nonce []byte,
) (*EncryptionWriter, error) {
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
}
if len(nonce) != NonceLen {
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
}
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
if err != nil {
return nil, fmt.Errorf("failed to derive backup key: %w", err)
}
block, err := aes.NewCipher(derivedKey)
if err != nil {
return nil, fmt.Errorf("failed to create cipher: %w", err)
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("failed to create GCM: %w", err)
}
writer := &EncryptionWriter{
baseWriter: baseWriter,
cipher: aesgcm,
buffer: make([]byte, 0, ChunkSize),
nonce: nonce,
chunkIndex: 0,
headerWritten: false,
salt: salt, // Store salt for lazy header writing
}
return writer, nil
}
func (w *EncryptionWriter) Write(p []byte) (n int, err error) {
// Write header on first write (lazy initialization)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return 0, fmt.Errorf("failed to write header: %w", err)
}
}
n = len(p)
w.buffer = append(w.buffer, p...)
for len(w.buffer) >= ChunkSize {
chunk := w.buffer[:ChunkSize]
if err := w.encryptAndWriteChunk(chunk); err != nil {
return 0, err
}
w.buffer = w.buffer[ChunkSize:]
}
return n, nil
}
func (w *EncryptionWriter) Close() error {
// Write header if it hasn't been written yet (in case Close is called without any writes)
if !w.headerWritten {
if err := w.writeHeader(w.salt, w.nonce); err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
}
if len(w.buffer) > 0 {
if err := w.encryptAndWriteChunk(w.buffer); err != nil {
return err
}
w.buffer = nil
}
return nil
}
func (w *EncryptionWriter) writeHeader(salt, nonce []byte) error {
header := make([]byte, HeaderLen)
copy(header[0:MagicBytesLen], []byte(MagicBytes))
copy(header[MagicBytesLen:MagicBytesLen+SaltLen], salt)
copy(header[MagicBytesLen+SaltLen:MagicBytesLen+SaltLen+NonceLen], nonce)
_, err := w.baseWriter.Write(header)
if err != nil {
return fmt.Errorf("failed to write header: %w", err)
}
w.headerWritten = true
return nil
}
func (w *EncryptionWriter) encryptAndWriteChunk(chunk []byte) error {
chunkNonce := w.generateChunkNonce()
encrypted := w.cipher.Seal(nil, chunkNonce, chunk, nil)
lengthBuf := make([]byte, 4)
binary.BigEndian.PutUint32(lengthBuf, uint32(len(encrypted)))
if _, err := w.baseWriter.Write(lengthBuf); err != nil {
return fmt.Errorf("failed to write chunk length: %w", err)
}
if _, err := w.baseWriter.Write(encrypted); err != nil {
return fmt.Errorf("failed to write encrypted chunk: %w", err)
}
w.chunkIndex++
return nil
}
func (w *EncryptionWriter) generateChunkNonce() []byte {
chunkNonce := make([]byte, NonceLen)
copy(chunkNonce, w.nonce)
binary.BigEndian.PutUint64(chunkNonce[4:], w.chunkIndex)
return chunkNonce
}

View File

@@ -0,0 +1,387 @@
package encryption
import (
"bytes"
"crypto/rand"
"io"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_EncryptDecryptRoundTrip_ReturnsOriginalData(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte(
"This is a test backup data that should be encrypted and then decrypted successfully.",
)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted := make([]byte, len(originalData))
n, err = io.ReadFull(reader, decrypted)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptDecryptRoundTrip_LargeData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := make([]byte, 100*1024)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
n, err := writer.Write(originalData)
require.NoError(t, err)
assert.Equal(t, len(originalData), n)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_EncryptionWriter_MultipleWrites_CombinesCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
part1 := []byte("First part of data. ")
part2 := []byte("Second part of data. ")
part3 := []byte("Third part of data.")
expectedData := append(append(part1, part2...), part3...)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(part1)
require.NoError(t, err)
_, err = writer.Write(part2)
require.NoError(t, err)
_, err = writer.Write(part3)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, expectedData, decrypted)
}
func Test_DecryptionReader_InvalidHeader_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
invalidHeader := make([]byte, HeaderLen)
copy(invalidHeader, []byte("INVALID!"))
invalidData := bytes.NewBuffer(invalidHeader)
_, err = NewDecryptionReader(invalidData, masterKey, backupID, salt, nonce)
assert.Error(t, err)
assert.Contains(t, err.Error(), "invalid magic bytes")
}
func Test_DecryptionReader_TamperedData_ReturnsError(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This data will be tampered with.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
encryptedBytes := encrypted.Bytes()
if len(encryptedBytes) > HeaderLen+10 {
encryptedBytes[HeaderLen+10] ^= 0xFF
}
tamperedBuffer := bytes.NewBuffer(encryptedBytes)
reader, err := NewDecryptionReader(tamperedBuffer, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_DeriveBackupKey_SameInputs_ReturnsSameKey(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey, backupID, salt)
require.NoError(t, err)
assert.Equal(t, key1, key2)
}
func Test_DeriveBackupKey_DifferentInputs_ReturnsDifferentKeys(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID1 := uuid.New()
backupID2 := uuid.New()
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
key1, err := DeriveBackupKey(masterKey1, backupID1, salt1)
require.NoError(t, err)
key2, err := DeriveBackupKey(masterKey2, backupID1, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key2)
key3, err := DeriveBackupKey(masterKey1, backupID2, salt1)
require.NoError(t, err)
assert.NotEqual(t, key1, key3)
key4, err := DeriveBackupKey(masterKey1, backupID1, salt2)
require.NoError(t, err)
assert.NotEqual(t, key1, key4)
}
func Test_EncryptionWriter_PartialChunk_HandledCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
smallData := []byte("Small data less than chunk size")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(smallData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, smallData, decrypted)
}
func Test_GenerateSalt_ReturnsCorrectLength(t *testing.T) {
salt, err := GenerateSalt()
require.NoError(t, err)
assert.Equal(t, SaltLen, len(salt))
}
func Test_GenerateSalt_GeneratesUniqueSalts(t *testing.T) {
salt1, err := GenerateSalt()
require.NoError(t, err)
salt2, err := GenerateSalt()
require.NoError(t, err)
assert.NotEqual(t, salt1, salt2)
}
func Test_GenerateNonce_ReturnsCorrectLength(t *testing.T) {
nonce, err := GenerateNonce()
require.NoError(t, err)
assert.Equal(t, NonceLen, len(nonce))
}
func Test_GenerateNonce_GeneratesUniqueNonces(t *testing.T) {
nonce1, err := GenerateNonce()
require.NoError(t, err)
nonce2, err := GenerateNonce()
require.NoError(t, err)
assert.NotEqual(t, nonce1, nonce2)
}
func Test_DecryptionReader_WrongMasterKey_ReturnsError(t *testing.T) {
masterKey1 := uuid.New().String() + uuid.New().String()
masterKey2 := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("Secret data")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey1, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey2, backupID, salt, nonce)
require.NoError(t, err)
_, err = io.ReadAll(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "authentication failed")
}
func Test_EncryptionWriter_EmptyData_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, 0, len(decrypted))
}
func Test_EncryptionWriter_MultipleChunks_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
dataSize := ChunkSize*3 + 1000
originalData := make([]byte, dataSize)
_, err = rand.Read(originalData)
require.NoError(t, err)
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
decrypted, err := io.ReadAll(reader)
require.NoError(t, err)
assert.Equal(t, originalData, decrypted)
}
func Test_DecryptionReader_SmallReads_WorksCorrectly(t *testing.T) {
masterKey := uuid.New().String() + uuid.New().String()
backupID := uuid.New()
salt, err := GenerateSalt()
require.NoError(t, err)
nonce, err := GenerateNonce()
require.NoError(t, err)
originalData := []byte("This is test data that will be read in small chunks.")
var encrypted bytes.Buffer
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
_, err = writer.Write(originalData)
require.NoError(t, err)
err = writer.Close()
require.NoError(t, err)
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
require.NoError(t, err)
var decrypted []byte
buf := make([]byte, 5)
for {
n, err := reader.Read(buf)
if n > 0 {
decrypted = append(decrypted, buf[:n]...)
}
if err == io.EOF {
break
}
require.NoError(t, err)
}
assert.Equal(t, originalData, decrypted)
}

View File

@@ -0,0 +1,52 @@
package encryption
import (
"crypto/rand"
"crypto/sha256"
"fmt"
"github.com/google/uuid"
"golang.org/x/crypto/pbkdf2"
)
const (
MagicBytes = "PGRSUS01"
MagicBytesLen = 8
SaltLen = 32
NonceLen = 12
ReservedLen = 12
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
ChunkSize = 32 * 1024
PBKDF2Iterations = 100000
)
func DeriveBackupKey(masterKey string, backupID uuid.UUID, salt []byte) ([]byte, error) {
if masterKey == "" {
return nil, fmt.Errorf("master key cannot be empty")
}
if len(salt) != SaltLen {
return nil, fmt.Errorf("salt must be %d bytes", SaltLen)
}
keyMaterial := []byte(masterKey + backupID.String())
derivedKey := pbkdf2.Key(keyMaterial, salt, PBKDF2Iterations, 32, sha256.New)
return derivedKey, nil
}
func GenerateSalt() ([]byte, error) {
salt := make([]byte, SaltLen)
if _, err := rand.Read(salt); err != nil {
return nil, fmt.Errorf("failed to generate salt: %w", err)
}
return salt, nil
}
func GenerateNonce() ([]byte, error) {
nonce := make([]byte, NonceLen)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("failed to generate nonce: %w", err)
}
return nonce, nil
}

View File

@@ -6,4 +6,5 @@ const (
BackupStatusInProgress BackupStatus = "IN_PROGRESS"
BackupStatusCompleted BackupStatus = "COMPLETED"
BackupStatusFailed BackupStatus = "FAILED"
BackupStatusCanceled BackupStatus = "CANCELED"
)

View File

@@ -1,6 +1,9 @@
package backups
import (
"context"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
@@ -19,6 +22,7 @@ type NotificationSender interface {
type CreateBackupUsecase interface {
Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -26,7 +30,7 @@ type CreateBackupUsecase interface {
backupProgressListener func(
completedMBs float64,
),
) error
) (*usecases_postgresql.BackupMetadata, error)
}
type BackupRemoveListener interface {

View File

@@ -1,8 +1,7 @@
package backups
import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/storages"
backups_config "postgresus-backend/internal/features/backups/config"
"time"
"github.com/google/uuid"
@@ -11,11 +10,8 @@ import (
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Database *databases.Database `json:"database" gorm:"foreignKey:DatabaseID"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
@@ -24,5 +20,9 @@ type Backup struct {
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}

View File

@@ -13,18 +13,20 @@ import (
type BackupRepository struct{}
func (r *BackupRepository) Save(backup *Backup) error {
if backup.DatabaseID == uuid.Nil || backup.StorageID == uuid.Nil {
return errors.New("database ID and storage ID are required")
}
db := storage.GetDb()
isNew := backup.ID == uuid.Nil
if isNew {
backup.ID = uuid.New()
return db.Create(backup).
Omit("Database", "Storage").
Error
}
return db.Save(backup).
Omit("Database", "Storage").
Error
}
@@ -33,8 +35,6 @@ func (r *BackupRepository) FindByDatabaseID(databaseID uuid.UUID) ([]*Backup, er
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -56,8 +56,6 @@ func (r *BackupRepository) FindByDatabaseIDWithLimit(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).
@@ -73,8 +71,6 @@ func (r *BackupRepository) FindByStorageID(storageID uuid.UUID) ([]*Backup, erro
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ?", storageID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -89,8 +85,6 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
First(&backup).Error; err != nil {
@@ -109,8 +103,6 @@ func (r *BackupRepository) FindByID(id uuid.UUID) (*Backup, error) {
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("id = ?", id).
First(&backup).Error; err != nil {
return nil, err
@@ -124,8 +116,6 @@ func (r *BackupRepository) FindByStatus(status BackupStatus) ([]*Backup, error)
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("status = ?", status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -143,8 +133,6 @@ func (r *BackupRepository) FindByStorageIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ? AND status = ?", storageID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -162,8 +150,6 @@ func (r *BackupRepository) FindByDatabaseIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND status = ?", databaseID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -185,8 +171,6 @@ func (r *BackupRepository) FindBackupsBeforeDate(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND created_at < ?", databaseID, date).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -195,3 +179,36 @@ func (r *BackupRepository) FindBackupsBeforeDate(
return backups, nil
}
func (r *BackupRepository) FindByDatabaseIDWithPagination(
databaseID uuid.UUID,
limit, offset int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).
Offset(offset).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
var count int64
if err := storage.
GetDb().
Model(&Backup{}).
Where("database_id = ?", databaseID).
Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}

View File

@@ -1,17 +1,26 @@
package backups
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"log/slog"
"slices"
"strings"
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
"slices"
"time"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
util_encryption "postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -23,12 +32,18 @@ type BackupService struct {
notifierService *notifiers.NotifierService
notificationSender NotificationSender
backupConfigService *backups_config.BackupConfigService
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor util_encryption.FieldEncryptor
createBackupUseCase CreateBackupUsecase
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextManager *BackupContextManager
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
@@ -62,34 +77,74 @@ func (s *BackupService) MakeBackupWithAuth(
return err
}
if database.UserID != user.ID {
return errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot create backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}
func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
) ([]*Backup, error) {
limit, offset int,
) (*GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot get backups for database without workspace")
}
backups, err := s.backupRepository.FindByDatabaseID(databaseID)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access backups for this database")
}
if limit <= 0 {
limit = 10
}
if offset < 0 {
offset = 0
}
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
if err != nil {
return nil, err
}
return backups, nil
total, err := s.backupRepository.CountByDatabaseID(databaseID)
if err != nil {
return nil, err
}
return &GetBackupsResponse{
Backups: backups,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *BackupService) DeleteBackup(
@@ -101,14 +156,37 @@ func (s *BackupService) DeleteBackup(
return err
}
if backup.Database.UserID != user.ID {
return errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot delete backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
return errors.New("backup is in progress")
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return s.deleteBackup(backup)
}
@@ -154,10 +232,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup := &Backup{
DatabaseID: databaseID,
Database: database,
StorageID: storage.ID,
Storage: storage,
StorageID: storage.ID,
Status: BackupStatusInProgress,
@@ -184,7 +259,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
}
}
err = s.createBackupUseCase.Execute(
ctx, cancel := context.WithCancel(context.Background())
s.backupContextManager.RegisterBackup(backup.ID, cancel)
defer s.backupContextManager.UnregisterBackup(backup.ID)
backupMetadata, err := s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
@@ -193,6 +273,34 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
@@ -225,6 +333,13 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup.Status = BackupStatusCompleted
backup.BackupDurationMs = time.Since(start).Milliseconds()
// Update backup with encryption metadata if provided
if backupMetadata != nil {
backup.EncryptionSalt = backupMetadata.EncryptionSalt
backup.EncryptionIV = backupMetadata.EncryptionIV
backup.Encryption = backupMetadata.Encryption
}
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save backup", "error", err)
return
@@ -265,6 +380,11 @@ func (s *BackupService) SendBackupNotification(
return
}
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
if err != nil {
return
}
for _, notifier := range database.Notifiers {
if !slices.Contains(
backupConfig.SendNotificationsOn,
@@ -276,9 +396,17 @@ func (s *BackupService) SendBackupNotification(
title := ""
switch notificationType {
case backups_config.NotificationBackupFailed:
title = fmt.Sprintf("❌ Backup failed for database \"%s\"", database.Name)
title = fmt.Sprintf(
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
case backups_config.NotificationBackupSuccess:
title = fmt.Sprintf("✅ Backup completed for database \"%s\"", database.Name)
title = fmt.Sprintf(
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
database.Name,
workspace.Name,
)
}
message := ""
@@ -319,6 +447,53 @@ func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
return s.backupRepository.FindByID(backupID)
}
func (s *BackupService) CancelBackup(
user *users_models.User,
backupID uuid.UUID,
) error {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return nil
}
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
@@ -328,16 +503,37 @@ func (s *BackupService) GetBackupFile(
return nil, err
}
if backup.Database.UserID != user.ID {
return nil, errors.New("user does not have access to this backup")
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
return storage.GetFile(backup.ID)
if database.WorkspaceID == nil {
return nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to download backup for this database")
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return s.getBackupReader(backupID)
}
func (s *BackupService) deleteBackup(backup *Backup) error {
@@ -352,9 +548,12 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
return err
}
err = storage.DeleteFile(backup.ID)
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
if err != nil {
return err
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)
@@ -389,3 +588,91 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
return nil
}
// GetBackupReader returns a reader for the backup file
// If encrypted, wraps with DecryptionReader
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return nil, fmt.Errorf("failed to find backup: %w", err)
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
if err != nil {
return nil, fmt.Errorf("failed to get storage: %w", err)
}
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
if err != nil {
return nil, fmt.Errorf("failed to get backup file: %w", err)
}
// If not encrypted, return raw reader
if backup.Encryption == backups_config.BackupEncryptionNone {
s.logger.Info("Returning non-encrypted backup", "backupId", backupID)
return fileReader, nil
}
// Decrypt on-the-fly for encrypted backups
if backup.Encryption != backups_config.BackupEncryptionEncrypted {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("unsupported encryption type: %s", backup.Encryption)
}
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
if err := fileReader.Close(); err != nil {
s.logger.Error("Failed to close file reader", "error", err)
}
return nil, fmt.Errorf("backup marked as encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := s.secretKeyService.GetSecretKey()
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to get master key: %w", err)
}
// Decode salt and IV
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to decode IV: %w", err)
}
// Wrap with decrypting reader
decryptionReader, err := encryption.NewDecryptionReader(
fileReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
if closeErr := fileReader.Close(); closeErr != nil {
s.logger.Error("Failed to close file reader", "error", closeErr)
}
return nil, fmt.Errorf("failed to create decrypting reader: %w", err)
}
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
return &decryptionReaderCloser{
decryptionReader,
fileReader,
}, nil
}

View File

@@ -1,15 +1,24 @@
package backups
import (
"context"
"errors"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
"postgresus-backend/internal/util/logger"
"strings"
"testing"
"time"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -17,15 +26,27 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
defer storages.RemoveTestStorage(storage.ID)
defer notifiers.RemoveTestNotifier(notifier)
defer databases.RemoveTestDatabase(database)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
@@ -36,9 +57,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
// Set up expectations
@@ -79,9 +105,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
backupService.MakeBackup(database.ID, true)
@@ -99,9 +130,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
notifiers.GetNotifierService(),
mockNotificationSender,
backups_config.GetBackupConfigService(),
encryption_secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
nil,
NewBackupContextManager(),
}
// capture arguments
@@ -137,6 +173,7 @@ type CreateFailedBackupUsecase struct {
}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -144,15 +181,16 @@ func (uc *CreateFailedBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return errors.New("backup failed")
return nil, errors.New("backup failed")
}
type CreateSuccessBackupUsecase struct {
}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -160,7 +198,11 @@ func (uc *CreateSuccessBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
backupProgressListener(10) // Assume we completed 10MB
return nil
return &usecases_postgresql.BackupMetadata{
EncryptionSalt: nil,
EncryptionIV: nil,
Encryption: backups_config.BackupEncryptionNone,
}, nil
}

View File

@@ -0,0 +1,20 @@
package backups
import (
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"github.com/gin-gonic/gin"
)
func CreateTestRouter() *gin.Engine {
return workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
GetBackupController(),
)
}

View File

@@ -1,6 +1,7 @@
package usecases
import (
"context"
"errors"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
@@ -14,8 +15,9 @@ type CreateBackupUsecase struct {
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
}
// Execute creates a backup of the database and returns the backup size in MB
// Execute creates a backup of the database and returns the backup metadata
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -23,9 +25,10 @@ func (uc *CreateBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*usecases_postgresql.BackupMetadata, error) {
if database.Type == databases.DatabaseTypePostgres {
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
backupID,
backupConfig,
database,
@@ -34,5 +37,5 @@ func (uc *CreateBackupUsecase) Execute(
)
}
return errors.New("database type not supported")
return nil, errors.New("database type not supported")
}

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -14,21 +15,39 @@ import (
"time"
"postgresus-backend/internal/config"
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"github.com/google/uuid"
)
const (
backupTimeout = 23 * time.Hour
shutdownCheckInterval = 1 * time.Second
copyBufferSize = 32 * 1024
progressReportIntervalMB = 1.0
pgConnectTimeout = 30
compressionLevel = 5
exitCodeAccessViolation = -1073741819
exitCodeGenericError = 1
exitCodeConnectionError = 2
)
type CreatePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyService *encryption_secrets.SecretKeyService
fieldEncryptor encryption.FieldEncryptor
}
// Execute creates a backup of the database
func (uc *CreatePostgresqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
@@ -36,7 +55,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
backupProgressListener func(
completedMBs float64,
),
) error {
) (*BackupMetadata, error) {
uc.logger.Info(
"Creating PostgreSQL backup via pg_dump custom format",
"databaseId",
@@ -46,41 +65,28 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
)
if !backupConfig.IsBackupsEnabled {
return fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
}
pg := db.Postgresql
if pg == nil {
return fmt.Errorf("postgresql database configuration is required for pg_dump backups")
return nil, fmt.Errorf("postgresql database configuration is required for pg_dump backups")
}
if pg.Database == nil || *pg.Database == "" {
return fmt.Errorf("database name is required for pg_dump backups")
return nil, fmt.Errorf("database name is required for pg_dump backups")
}
args := []string{
"-Fc", // custom format with built-in compression
"--no-password", // Use environment variable for password, prevent prompts
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose", // Add verbose output to help with debugging
}
args := uc.buildPgDumpArgs(pg)
// Use zstd compression level 5 for PostgreSQL 15+ (better compression and speed)
// Fall back to gzip compression level 5 for older versions
if pg.Version == tools.PostgresqlVersion13 || pg.Version == tools.PostgresqlVersion14 ||
pg.Version == tools.PostgresqlVersion15 {
args = append(args, "-Z", "5")
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
} else {
args = append(args, "--compress=zstd:5")
uc.logger.Info("Using zstd compression level 5", "version", pg.Version)
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, pg.Password)
if err != nil {
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
}
return uc.streamToStorage(
ctx,
backupID,
backupConfig,
tools.GetPostgresqlExecutable(
@@ -90,7 +96,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
config.GetEnv().PostgresesInstallDir,
),
args,
pg.Password,
decryptedPassword,
storage,
db,
backupProgressListener,
@@ -99,6 +105,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
// streamToStorage streams pg_dump output directly to storage
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
pgBin string,
@@ -107,36 +114,15 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
storage *storages.Storage,
db *databases.Database,
backupProgressListener func(completedMBs float64),
) error {
) (*BackupMetadata, error) {
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
// if backup not fit into 23 hours, Postgresus
// seems not to work for such database size
ctx, cancel := context.WithTimeout(context.Background(), 23*time.Hour)
ctx, cancel := uc.createBackupContext(parentCtx)
defer cancel()
// Monitor for shutdown and cancel context if needed
go func() {
ticker := time.NewTicker(1 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
// Create temporary .pgpass file as a more reliable alternative to PGPASSWORD
pgpassFile, err := uc.createTempPgpassFile(db.Postgresql, password)
pgpassFile, err := uc.setupPgpassFile(db.Postgresql, password)
if err != nil {
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
return nil, err
}
defer func() {
if pgpassFile != "" {
@@ -144,87 +130,21 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
}
}()
// Verify .pgpass file was created successfully
if pgpassFile == "" {
return fmt.Errorf("temporary .pgpass file was not created")
}
// Verify .pgpass file was created correctly
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return fmt.Errorf("failed to verify .pgpass file: %w", err)
}
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
// Start with system environment variables to preserve Windows PATH, SystemRoot, etc.
cmd.Env = os.Environ()
// Use the .pgpass file for authentication
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
// Debug password setup (without exposing the actual password)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", backupConfig.CpuCount,
)
// Add PostgreSQL-specific environment variables
cmd.Env = append(cmd.Env, "PGCLIENTENCODING=UTF8")
cmd.Env = append(cmd.Env, "PGCONNECT_TIMEOUT=30")
// Add encoding-related environment variables to handle character encoding issues
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
// Add PostgreSQL-specific encoding settings
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
shouldRequireSSL := db.Postgresql.IsHttps
// Require SSL when explicitly configured
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", db.Postgresql.IsHttps)
} else {
// SSL not explicitly required, but prefer it if available
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", db.Postgresql.IsHttps)
}
// Set other SSL parameters to avoid certificate issues
cmd.Env = append(cmd.Env, "PGSSLCERT=") // No client certificate
cmd.Env = append(cmd.Env, "PGSSLKEY=") // No client key
cmd.Env = append(cmd.Env, "PGSSLROOTCERT=") // No root certificate verification
cmd.Env = append(cmd.Env, "PGSSLCRL=") // No certificate revocation list
// Verify executable exists and is accessible
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf(
"PostgreSQL executable not found or not accessible: %s - %w",
pgBin,
err,
)
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, backupConfig.CpuCount, pgBin); err != nil {
return nil, err
}
pgStdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("stdout pipe: %w", err)
return nil, fmt.Errorf("stdout pipe: %w", err)
}
pgStderr, err := cmd.StderrPipe()
if err != nil {
return fmt.Errorf("stderr pipe: %w", err)
return nil, fmt.Errorf("stderr pipe: %w", err)
}
// Capture stderr in a separate goroutine to ensure we don't miss any error output
@@ -234,23 +154,31 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
stderrCh <- stderrOutput
}()
// A pipe connecting pg_dump output → storage
storageReader, storageWriter := io.Pipe()
// Create a counting writer to track bytes
countingWriter := &CountingWriter{writer: storageWriter}
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
backupID,
backupConfig,
storageWriter,
)
if err != nil {
return nil, err
}
countingWriter := &CountingWriter{writer: finalWriter}
// The backup ID becomes the object key / filename in storage
// Start streaming into storage in its own goroutine
saveErrCh := make(chan error, 1)
go func() {
saveErrCh <- storage.SaveFile(uc.logger, backupID, storageReader)
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
saveErrCh <- saveErr
}()
// Start pg_dump
if err = cmd.Start(); err != nil {
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
return nil, fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
}
// Copy pg output directly to storage with shutdown checks
@@ -272,23 +200,17 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
bytesWritten := <-bytesWrittenCh
waitErr := cmd.Wait()
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
}
<-saveErrCh // Wait for storage to finish
return fmt.Errorf("backup cancelled due to shutdown")
// Check for shutdown or cancellation before finalizing
select {
case <-ctx.Done():
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
return nil, uc.checkCancellationReason()
default:
}
// Close the pipe writer to signal end of data
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
}
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
<-saveErrCh
return nil, err
}
// Wait until storage ends reading
@@ -303,134 +225,34 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
switch {
case waitErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
// Enhanced error handling for PostgreSQL connection and SSL issues
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf(
"%s failed: %v stderr: %s",
filepath.Base(pgBin),
waitErr,
stderrStr,
)
// Check for specific PostgreSQL error patterns
if exitErr, ok := waitErr.(*exec.ExitError); ok {
exitCode := exitErr.ExitCode()
// Enhanced debugging for exit status 1 with empty stderr
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=30",
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
},
)
errorMsg = fmt.Sprintf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"args", args,
"exitCode", fmt.Sprintf("0x%X", uint32(exitCode)),
)
errorMsg = fmt.Sprintf(
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
} else if exitCode == 1 || exitCode == 2 {
// Check for common connection and authentication issues
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). The server may not allow connections from your IP address or may require different authentication settings. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. stderr: %s",
len(password),
password == "",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
errorMsg = fmt.Sprintf(
"PostgreSQL SSL connection failed. The server may require SSL encryption or have SSL configuration issues. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection refused. Check if the server is running and accessible from your network. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
errorMsg = fmt.Sprintf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
} else if containsIgnoreCase(stderrStr, "timeout") {
errorMsg = fmt.Sprintf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
}
}
return errors.New(errorMsg)
return nil, uc.buildPgDumpErrorMessage(waitErr, stderrOutput, pgBin, args, password)
case copyErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("copy to storage: %w", copyErr)
return nil, fmt.Errorf("copy to storage: %w", copyErr)
case saveErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
if err := uc.checkCancellation(ctx); err != nil {
return nil, err
}
return fmt.Errorf("save to storage: %w", saveErr)
return nil, fmt.Errorf("save to storage: %w", saveErr)
}
return nil
return &backupMetadata, nil
}
// copyWithShutdownCheck copies data from src to dst while checking for shutdown
func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
ctx context.Context,
dst io.Writer,
src io.Reader,
backupProgressListener func(completedMBs float64),
) (int64, error) {
buf := make([]byte, 32*1024) // 32KB buffer
buf := make([]byte, copyBufferSize)
var totalBytesWritten int64
// Progress reporting interval - report every 1MB of data
var lastReportedMB float64
const reportIntervalMB = 1.0
for {
select {
@@ -463,12 +285,9 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
totalBytesWritten += int64(bytesWritten)
// Report progress based on total size
if backupProgressListener != nil {
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
// Only report if we've written at least 1MB more data than last report
if currentSizeMB >= lastReportedMB+reportIntervalMB {
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
backupProgressListener(currentSizeMB)
lastReportedMB = currentSizeMB
}
@@ -479,7 +298,6 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
if readErr != io.EOF {
return totalBytesWritten, readErr
}
break
}
}
@@ -487,12 +305,412 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
return totalBytesWritten, nil
}
// containsIgnoreCase checks if a string contains a substring, ignoring case
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlDatabase) []string {
args := []string{
"-Fc",
"--no-password",
"-h", pg.Host,
"-p", strconv.Itoa(pg.Port),
"-U", pg.Username,
"-d", *pg.Database,
"--verbose",
}
compressionArgs := uc.getCompressionArgs(pg.Version)
return append(args, compressionArgs...)
}
func (uc *CreatePostgresqlBackupUsecase) getCompressionArgs(
version tools.PostgresqlVersion,
) []string {
if uc.isOlderPostgresVersion(version) {
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", version)
return []string{"-Z", strconv.Itoa(compressionLevel)}
}
uc.logger.Info("Using zstd compression level 5", "version", version)
return []string{fmt.Sprintf("--compress=zstd:%d", compressionLevel)}
}
func (uc *CreatePostgresqlBackupUsecase) isOlderPostgresVersion(
version tools.PostgresqlVersion,
) bool {
return version == tools.PostgresqlVersion12 ||
version == tools.PostgresqlVersion13 ||
version == tools.PostgresqlVersion14 ||
version == tools.PostgresqlVersion15
}
func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
parentCtx context.Context,
) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
go func() {
ticker := time.NewTicker(shutdownCheckInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if config.IsShouldShutdown() {
cancel()
return
}
}
}
}()
return ctx, cancel
}
func (uc *CreatePostgresqlBackupUsecase) setupPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
) (string, error) {
pgpassFile, err := uc.createTempPgpassFile(pgConfig, password)
if err != nil {
return "", fmt.Errorf("failed to create temporary .pgpass file: %w", err)
}
if pgpassFile == "" {
return "", fmt.Errorf("temporary .pgpass file was not created")
}
if info, err := os.Stat(pgpassFile); err == nil {
uc.logger.Info("Temporary .pgpass file created successfully",
"pgpassFile", pgpassFile,
"size", info.Size(),
"mode", info.Mode(),
)
} else {
return "", fmt.Errorf("failed to verify .pgpass file: %w", err)
}
return pgpassFile, nil
}
func (uc *CreatePostgresqlBackupUsecase) setupPgEnvironment(
cmd *exec.Cmd,
pgpassFile string,
shouldRequireSSL bool,
password string,
cpuCount int,
pgBin string,
) error {
cmd.Env = os.Environ()
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
uc.logger.Info("Setting up PostgreSQL environment",
"passwordLength", len(password),
"passwordEmpty", password == "",
"pgBin", pgBin,
"usingPgpassFile", true,
"parallelJobs", cpuCount,
)
cmd.Env = append(cmd.Env,
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT="+strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
)
if shouldRequireSSL {
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
uc.logger.Info("Using required SSL mode", "configuredHttps", shouldRequireSSL)
} else {
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
uc.logger.Info("Using preferred SSL mode", "configuredHttps", shouldRequireSSL)
}
cmd.Env = append(cmd.Env,
"PGSSLCERT=",
"PGSSLKEY=",
"PGSSLROOTCERT=",
"PGSSLCRL=",
)
if _, err := exec.LookPath(pgBin); err != nil {
return fmt.Errorf("PostgreSQL executable not found or not accessible: %s - %w", pgBin, err)
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
storageWriter io.WriteCloser,
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
metadata := BackupMetadata{}
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
metadata.Encryption = backups_config.BackupEncryptionNone
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
return storageWriter, nil, metadata, nil
}
salt, err := backup_encryption.GenerateSalt()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
}
nonce, err := backup_encryption.GenerateNonce()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
}
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
}
encWriter, err := backup_encryption.NewEncryptionWriter(
storageWriter,
masterKey,
backupID,
salt,
nonce,
)
if err != nil {
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
}
saltBase64 := base64.StdEncoding.EncodeToString(salt)
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
metadata.EncryptionSalt = &saltBase64
metadata.EncryptionIV = &nonceBase64
metadata.Encryption = backups_config.BackupEncryptionEncrypted
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
return encWriter, encWriter, metadata, nil
}
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
saveErrCh chan error,
) {
if encryptionWriter != nil {
go func() {
if closeErr := encryptionWriter.Close(); closeErr != nil {
uc.logger.Error(
"Failed to close encrypting writer during cancellation",
"error",
closeErr,
)
}
}()
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
}
<-saveErrCh
}
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
encryptionWriter *backup_encryption.EncryptionWriter,
storageWriter io.WriteCloser,
) error {
encryptionCloseErrCh := make(chan error, 1)
if encryptionWriter != nil {
go func() {
closeErr := encryptionWriter.Close()
if closeErr != nil {
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
}
encryptionCloseErrCh <- closeErr
}()
} else {
encryptionCloseErrCh <- nil
}
encryptionCloseErr := <-encryptionCloseErrCh
if encryptionCloseErr != nil {
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
}
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
}
if err := storageWriter.Close(); err != nil {
uc.logger.Error("Failed to close pipe writer", "error", err)
return err
}
return nil
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellation(ctx context.Context) error {
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
return nil
}
}
func (uc *CreatePostgresqlBackupUsecase) checkCancellationReason() error {
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
}
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
password string,
) error {
stderrStr := string(stderrOutput)
errorMsg := fmt.Sprintf("%s failed: %v stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
exitErr, ok := waitErr.(*exec.ExitError)
if !ok {
return errors.New(errorMsg)
}
exitCode := exitErr.ExitCode()
if exitCode == exitCodeGenericError && strings.TrimSpace(stderrStr) == "" {
return uc.handleExitCode1NoStderr(pgBin, args)
}
if exitCode == exitCodeAccessViolation {
return uc.handleAccessViolation(pgBin, stderrStr)
}
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
return uc.handleConnectionErrors(stderrStr, password)
}
return errors.New(errorMsg)
}
func (uc *CreatePostgresqlBackupUsecase) handleExitCode1NoStderr(
pgBin string,
args []string,
) error {
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
"pgBin", pgBin,
"args", args,
"env_vars", []string{
"PGCLIENTENCODING=UTF8",
"PGCONNECT_TIMEOUT=" + strconv.Itoa(pgConnectTimeout),
"LC_ALL=C.UTF-8",
"LANG=C.UTF-8",
"PGOPTIONS=--client-encoding=UTF8",
},
)
return fmt.Errorf(
"%s failed with exit status 1 but provided no error details. "+
"This often indicates: "+
"1) Connection timeout or refused connection, "+
"2) Authentication failure with incorrect credentials, "+
"3) Database does not exist, "+
"4) Network connectivity issues, "+
"5) PostgreSQL server not running. "+
"Command executed: %s %s",
filepath.Base(pgBin),
pgBin,
strings.Join(args, " "),
)
}
func (uc *CreatePostgresqlBackupUsecase) handleAccessViolation(
pgBin string,
stderrStr string,
) error {
uc.logger.Error("PostgreSQL tool crashed with access violation",
"pgBin", pgBin,
"exitCode", "0xC0000005",
)
return fmt.Errorf(
"%s crashed with access violation (0xC0000005). "+
"This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. "+
"stderr: %s",
filepath.Base(pgBin),
stderrStr,
)
}
func (uc *CreatePostgresqlBackupUsecase) handleConnectionErrors(
stderrStr string,
password string,
) error {
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
return fmt.Errorf(
"PostgreSQL connection rejected by server configuration (pg_hba.conf). "+
"The server may not allow connections from your IP address or may require different authentication settings. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "no password supplied") ||
containsIgnoreCase(stderrStr, "fe_sendauth") {
return fmt.Errorf(
"PostgreSQL authentication failed - no password supplied. "+
"PGPASSWORD environment variable may not be working correctly on this system. "+
"Password length: %d, Password empty: %v. "+
"Consider using a .pgpass file as an alternative. "+
"stderr: %s",
len(password),
password == "",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
return fmt.Errorf(
"PostgreSQL SSL connection failed. "+
"The server may require SSL encryption or have SSL configuration issues. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
return fmt.Errorf(
"PostgreSQL connection refused. "+
"Check if the server is running and accessible from your network. "+
"stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "authentication") ||
containsIgnoreCase(stderrStr, "password") {
return fmt.Errorf(
"PostgreSQL authentication failed. Check username and password. stderr: %s",
stderrStr,
)
}
if containsIgnoreCase(stderrStr, "timeout") {
return fmt.Errorf(
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
stderrStr,
)
}
return fmt.Errorf("PostgreSQL connection or authentication error. stderr: %s", stderrStr)
}
// createTempPgpassFile creates a temporary .pgpass file with the given password
func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
pgConfig *pgtypes.PostgresqlDatabase,
password string,
@@ -508,7 +726,6 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
password,
)
// it always create unique directory like /tmp/pgpass-1234567890
tempDir, err := os.MkdirTemp("", "pgpass")
if err != nil {
return "", fmt.Errorf("failed to create temporary directory: %w", err)
@@ -522,3 +739,7 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
return pgpassFile, nil
}
func containsIgnoreCase(str, substr string) bool {
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
}

View File

@@ -1,11 +1,15 @@
package usecases_postgresql
import (
"postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
logger.GetLogger(),
secrets.GetSecretKeyService(),
encryption.GetFieldEncryptor(),
}
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {

View File

@@ -0,0 +1,15 @@
package usecases_postgresql
import backups_config "postgresus-backend/internal/features/backups/config"
type EncryptionMetadata struct {
Salt string
IV string
Encryption backups_config.BackupEncryption
}
type BackupMetadata struct {
EncryptionSalt *string
EncryptionIV *string
Encryption backups_config.BackupEncryption
}

View File

@@ -2,7 +2,7 @@ package backups_config
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type BackupConfigController struct {
backupConfigService *BackupConfigService
userService *users.UserService
}
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
@@ -21,35 +20,29 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
// SaveBackupConfig
// @Summary Save backup configuration
// @Description Save or update backup configuration for a database
// @Description Save or update backup configuration for a database. Encryption can be set to NONE (no encryption) or ENCRYPTED (AES-256-GCM encryption).
// @Tags backup-configs
// @Accept json
// @Produce json
// @Param request body BackupConfig true "Backup configuration data"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 500
// @Param request body BackupConfig true "Backup configuration data (encryption field: NONE or ENCRYPTED)"
// @Success 200 {object} BackupConfig "Returns the saved backup configuration including encryption settings"
// @Failure 400 {object} map[string]string "Invalid encryption value or other validation errors"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 500 {object} map[string]string "Internal server error"
// @Router /backup-configs/save [post]
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var requestDTO BackupConfig
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
// make sure we rely on full .Storage object
requestDTO.StorageID = nil
@@ -64,40 +57,28 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
// GetBackupConfigByDbID
// @Summary Get backup configuration by database ID
// @Description Get backup configuration for a specific database
// @Description Get backup configuration for a specific database including encryption settings (NONE or ENCRYPTED)
// @Tags backup-configs
// @Produce json
// @Param id path string true "Database ID"
// @Success 200 {object} BackupConfig
// @Failure 400
// @Failure 401
// @Failure 404
// @Success 200 {object} BackupConfig "Returns backup configuration with encryption field"
// @Failure 400 {object} map[string]string "Invalid database ID"
// @Failure 401 {object} map[string]string "User not authenticated"
// @Failure 404 {object} map[string]string "Backup configuration not found"
// @Router /backup-configs/database/{id} [get]
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
_, err = c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
backupConfig, err := c.backupConfigService.GetBackupConfigByDbIdWithAuth(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "backup configuration not found"})
@@ -119,24 +100,18 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
// @Failure 500
// @Router /backup-configs/storage/{id}/is-using [get]
func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
isUsing, err := c.backupConfigService.IsStorageUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View File

@@ -0,0 +1,493 @@
package backups_config
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/period"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetBackupConfigController(),
)
return router
}
func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer cannot save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can save backup config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
}
var response BackupConfig
testResp := test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, 2, response.CpuCount)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get backup config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get backup config",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
var response BackupConfig
testResp := test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.BackupInterval)
} else {
assert.Contains(t, string(testResp.Body), "backup configuration not found")
}
})
}
}
func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, 1, response.CpuCount)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isStorageOwner bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "storage owner can check storage usage",
isStorageOwner: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-storage-owner cannot check storage usage",
isStorageOwner: false,
expectSuccess: false,
expectedStatusCode: http.StatusInternalServerError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
storageOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace(
"Test Workspace",
storageOwner,
router,
)
storage := createTestStorage(workspace.ID)
var testUserToken string
if tt.isStorageOwner {
testUserToken = storageOwner.Token
} else {
otherUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = otherUser.Token
}
if tt.expectSuccess {
var response map[string]bool
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
isUsing, exists := response["isUsing"]
assert.True(t, exists)
assert.False(t, isUsing)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "error")
}
// Cleanup
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionNone, response.Encryption)
}
func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionEncrypted,
}
var response BackupConfig
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+owner.Token,
request,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.Equal(t, BackupEncryptionEncrypted, response.Encryption)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
return storages.CreateTestStorage(workspaceID)
}

View File

@@ -3,7 +3,7 @@ package backups_config
import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -11,11 +11,11 @@ var backupConfigService = &BackupConfigService{
backupConfigRepository,
databases.GetDatabaseService(),
storages.GetStorageService(),
workspaces_services.GetWorkspaceService(),
nil,
}
var backupConfigController = &BackupConfigController{
backupConfigService,
users.GetUserService(),
}
func GetBackupConfigController() *BackupConfigController {

View File

@@ -6,3 +6,10 @@ const (
NotificationBackupFailed BackupNotificationType = "BACKUP_FAILED"
NotificationBackupSuccess BackupNotificationType = "BACKUP_SUCCESS"
)
type BackupEncryption string
const (
BackupEncryptionNone BackupEncryption = "NONE"
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
)

View File

@@ -31,6 +31,8 @@ type BackupConfig struct {
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
CpuCount int `json:"cpuCount" gorm:"type:int;not null"`
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
}
func (h *BackupConfig) TableName() string {
@@ -88,6 +90,11 @@ func (b *BackupConfig) Validate() error {
return errors.New("max failed tries count must be greater than 0")
}
if b.Encryption != "" && b.Encryption != BackupEncryptionNone &&
b.Encryption != BackupEncryptionEncrypted {
return errors.New("encryption must be NONE or ENCRYPTED")
}
return nil
}
@@ -103,5 +110,6 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
IsRetryIfFailed: b.IsRetryIfFailed,
MaxFailedTriesCount: b.MaxFailedTriesCount,
CpuCount: b.CpuCount,
Encryption: b.Encryption,
}
}

View File

@@ -1,10 +1,13 @@
package backups_config
import (
"errors"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/period"
"github.com/google/uuid"
@@ -14,6 +17,7 @@ type BackupConfigService struct {
backupConfigRepository *BackupConfigRepository
databaseService *databases.DatabaseService
storageService *storages.StorageService
workspaceService *workspaces_services.WorkspaceService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -32,11 +36,23 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
return nil, err
}
_, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
database, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot save backup config for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to modify backup configuration")
}
return s.SaveBackupConfig(backupConfig)
}
@@ -66,19 +82,6 @@ func (s *BackupConfigService) SaveBackupConfig(
}
}
if !backupConfig.IsBackupsEnabled && existingConfig.StorageID != nil {
if err := s.dbStorageChangeListener.OnBeforeBackupsStorageChange(
backupConfig.DatabaseID,
); err != nil {
return nil, err
}
// we clear storage for disabled backups to allow
// storage removal for unused storages
backupConfig.Storage = nil
backupConfig.StorageID = nil
}
return s.backupConfigRepository.Save(backupConfig)
}
@@ -144,6 +147,10 @@ func (s *BackupConfigService) OnDatabaseCopied(originalDatabaseID, newDatabaseID
}
}
func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) error {
return s.initializeDefaultConfig(databaseID)
}
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {
@@ -164,6 +171,7 @@ func (s *BackupConfigService) initializeDefaultConfig(
CpuCount: 1,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
Encryption: BackupEncryptionNone,
})
return err

View File

@@ -2,15 +2,18 @@ package databases
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type DatabaseController struct {
databaseService *DatabaseService
userService *users.UserService
databaseService *DatabaseService
userService *users_services.UserService
workspaceService *workspaces_services.WorkspaceService
}
func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
@@ -28,36 +31,35 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
// CreateDatabase
// @Summary Create a new database
// @Description Create a new database configuration
// @Description Create a new database configuration in a workspace
// @Tags databases
// @Accept json
// @Produce json
// @Param request body Database true "Database creation data"
// @Param request body Database true "Database creation data with workspaceId"
// @Success 201 {object} Database
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /databases/create [post]
func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
if request.WorkspaceID == nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
database, err := c.databaseService.CreateDatabase(user, &request)
database, err := c.databaseService.CreateDatabase(user, *request.WorkspaceID, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -79,24 +81,18 @@ func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
// @Failure 500
// @Router /databases/update [post]
func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.UpdateDatabase(user, &request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -116,24 +112,18 @@ func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id} [delete]
func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.DeleteDatabase(user, id); err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -153,24 +143,18 @@ func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
// @Failure 401
// @Router /databases/{id} [get]
func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
database, err := c.databaseService.GetDatabase(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -181,30 +165,38 @@ func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
}
// GetDatabases
// @Summary Get databases
// @Description Get all databases for the authenticated user
// @Summary Get databases by workspace
// @Description Get all databases for a specific workspace
// @Tags databases
// @Produce json
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Database
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /databases [get]
func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
databases, err := c.databaseService.GetDatabasesByUser(user)
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
databases, err := c.databaseService.GetDatabasesByWorkspace(user, workspaceID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -222,24 +214,18 @@ func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id}/test-connection [post]
func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.TestDatabaseConnection(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -259,27 +245,18 @@ func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
// @Failure 401
// @Router /databases/test-connection-direct [post]
func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
_, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
// Set user ID for validation purposes
request.UserID = user.ID
if err := c.databaseService.TestDatabaseConnectionDirect(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -300,24 +277,18 @@ func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
// @Failure 500
// @Router /databases/notifier/{id}/is-using [get]
func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
isUsing, err := c.databaseService.IsNotifierUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -339,24 +310,18 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id}/copy [post]
func (c *DatabaseController) CopyDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
copiedDatabase, err := c.databaseService.CopyDatabase(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

File diff suppressed because it is too large Load Diff

View File

@@ -5,9 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/tools"
"regexp"
"slices"
"time"
"github.com/google/uuid"
@@ -59,11 +59,51 @@ func (p *PostgresqlDatabase) Validate() error {
return nil
}
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
func (p *PostgresqlDatabase) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
return testSingleDatabaseConnection(logger, ctx, p)
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
}
func (p *PostgresqlDatabase) HideSensitiveData() {
if p == nil {
return
}
p.Password = ""
}
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
p.Version = incoming.Version
p.Host = incoming.Host
p.Port = incoming.Port
p.Username = incoming.Username
p.Database = incoming.Database
p.IsHttps = incoming.IsHttps
if incoming.Password != "" {
p.Password = incoming.Password
}
}
func (p *PostgresqlDatabase) EncryptSensitiveFields(
databaseID uuid.UUID,
encryptor encryption.FieldEncryptor,
) error {
if p.Password != "" {
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
if err != nil {
return err
}
p.Password = encrypted
}
return nil
}
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
@@ -71,14 +111,22 @@ func testSingleDatabaseConnection(
logger *slog.Logger,
ctx context.Context,
postgresDb *PostgresqlDatabase,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error {
// For single database backup, we need to connect to the specific database
if postgresDb.Database == nil || *postgresDb.Database == "" {
return errors.New("database name is required for single database backup (pg_dump)")
}
// Decrypt password if needed
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
if err != nil {
return fmt.Errorf("failed to decrypt password: %w", err)
}
// Build connection string for the specific database
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
// Test connection
conn, err := pgx.Connect(ctx, connStr)
@@ -161,7 +209,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
}
// buildConnectionStringForDB builds connection string for specific database
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
sslMode := "disable"
if p.IsHttps {
sslMode = "require"
@@ -171,106 +219,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
p.Host,
p.Port,
p.Username,
p.Password,
password,
dbName,
sslMode,
)
}
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
if len(extensions) == 0 {
return nil
func decryptPasswordIfNeeded(
password string,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) (string, error) {
if encryptor == nil {
return password, nil
}
if p.Database == nil || *p.Database == "" {
return errors.New("database name is required for installing extensions")
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// Build connection string for the specific database
connStr := buildConnectionStringForDB(p, *p.Database)
// Connect to database
conn, err := pgx.Connect(ctx, connStr)
if err != nil {
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
}
defer func() {
if closeErr := conn.Close(ctx); closeErr != nil {
fmt.Println("failed to close connection: %w", closeErr)
}
}()
// Check which extensions are already installed
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
if err != nil {
return fmt.Errorf("failed to check installed extensions: %w", err)
}
// Install missing extensions
for _, extension := range extensions {
if contains(installedExtensions, string(extension)) {
continue // Extension already installed
}
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
}
}
return nil
}
// getInstalledExtensions queries the database for currently installed extensions
func (p *PostgresqlDatabase) getInstalledExtensions(
ctx context.Context,
conn *pgx.Conn,
) ([]string, error) {
query := "SELECT extname FROM pg_extension"
rows, err := conn.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
}
defer rows.Close()
var extensions []string
for rows.Next() {
var extname string
if err := rows.Scan(&extname); err != nil {
return nil, fmt.Errorf("failed to scan extension name: %w", err)
}
extensions = append(extensions, extname)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
}
return extensions, nil
}
// installExtension installs a single PostgreSQL extension
func (p *PostgresqlDatabase) installExtension(
ctx context.Context,
conn *pgx.Conn,
extensionName string,
) error {
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
_, err := conn.Exec(ctx, query)
if err != nil {
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
}
return nil
}
// contains checks if a string slice contains a specific string
func contains(slice []string, item string) bool {
return slices.Contains(slice, item)
return encryptor.Decrypt(databaseID, password)
}

View File

@@ -1,8 +1,11 @@
package databases
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/users"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -15,11 +18,15 @@ var databaseService = &DatabaseService{
[]DatabaseCreationListener{},
[]DatabaseRemoveListener{},
[]DatabaseCopyListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var databaseController = &DatabaseController{
databaseService,
users.GetUserService(),
users_services.GetUserService(),
workspaces_services.GetWorkspaceService(),
}
func GetDatabaseService() *DatabaseService {
@@ -29,3 +36,7 @@ func GetDatabaseService() *DatabaseService {
func GetDatabaseController() *DatabaseController {
return databaseController
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
}

View File

@@ -2,6 +2,7 @@ package databases
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -11,7 +12,13 @@ type DatabaseValidator interface {
}
type DatabaseConnector interface {
TestConnection(logger *slog.Logger) error
TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
databaseID uuid.UUID,
) error
HideSensitiveData()
}
type DatabaseCreationListener interface {

View File

@@ -5,16 +5,20 @@ import (
"log/slog"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
)
type Database struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;type:uuid;not null"`
Name string `json:"name" gorm:"column:name;type:text;not null"`
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
// WorkspaceID can be null when a database is created via restore operation
// outside the context of any workspace
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;type:uuid"`
Name string `json:"name" gorm:"column:name;type:text;not null"`
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
@@ -53,8 +57,35 @@ func (d *Database) ValidateUpdate(old, new Database) error {
return nil
}
func (d *Database) TestConnection(logger *slog.Logger) error {
return d.getSpecificDatabase().TestConnection(logger)
func (d *Database) TestConnection(
logger *slog.Logger,
encryptor encryption.FieldEncryptor,
) error {
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
if d.Postgresql != nil {
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
}
return nil
}
func (d *Database) Update(incoming *Database) {
d.Name = incoming.Name
d.Type = incoming.Type
d.Notifiers = incoming.Notifiers
switch d.Type {
case DatabaseTypePostgres:
if d.Postgresql != nil && incoming.Postgresql != nil {
d.Postgresql.Update(incoming.Postgresql)
}
}
}
func (d *Database) getSpecificDatabase() DatabaseConnector {

View File

@@ -92,14 +92,14 @@ func (r *DatabaseRepository) FindByID(id uuid.UUID) (*Database, error) {
return &database, nil
}
func (r *DatabaseRepository) FindByUserID(userID uuid.UUID) ([]*Database, error) {
func (r *DatabaseRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Database, error) {
var databases []*Database
if err := storage.
GetDb().
Preload("Postgresql").
Preload("Notifiers").
Where("user_id = ?", userID).
Where("workspace_id = ?", workspaceID).
Order("CASE WHEN health_status = 'UNAVAILABLE' THEN 1 WHEN health_status = 'AVAILABLE' THEN 2 WHEN health_status IS NULL THEN 3 ELSE 4 END, name ASC").
Find(&databases).Error; err != nil {
return nil, err

View File

@@ -2,11 +2,16 @@ package databases
import (
"errors"
"fmt"
"log/slog"
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
users_models "postgresus-backend/internal/features/users/models"
"time"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -19,6 +24,10 @@ type DatabaseService struct {
dbCreationListener []DatabaseCreationListener
dbRemoveListener []DatabaseRemoveListener
dbCopyListener []DatabaseCopyListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *DatabaseService) AddDbCreationListener(
@@ -41,15 +50,28 @@ func (s *DatabaseService) AddDbCopyListener(
func (s *DatabaseService) CreateDatabase(
user *users_models.User,
workspaceID uuid.UUID,
database *Database,
) (*Database, error) {
database.UserID = user.ID
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to create database in this workspace")
}
database.WorkspaceID = &workspaceID
if err := database.Validate(); err != nil {
return nil, err
}
database, err := s.dbRepository.Save(database)
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
database, err = s.dbRepository.Save(database)
if err != nil {
return nil, err
}
@@ -58,6 +80,12 @@ func (s *DatabaseService) CreateDatabase(
listener.OnDatabaseCreated(database.ID)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database created: %s", database.Name),
&user.ID,
&workspaceID,
)
return database, nil
}
@@ -74,24 +102,43 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if existingDatabase.UserID != user.ID {
return errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return errors.New("cannot update database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to update this database")
}
// Validate the update
if err := database.ValidateUpdate(*existingDatabase, *database); err != nil {
return err
}
if err := database.Validate(); err != nil {
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return err
}
_, err = s.dbRepository.Save(database)
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
}
_, err = s.dbRepository.Save(existingDatabase)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return nil
}
@@ -104,8 +151,16 @@ func (s *DatabaseService) DeleteDatabase(
return err
}
if existingDatabase.UserID != user.ID {
return errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return errors.New("cannot delete database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to delete this database")
}
for _, listener := range s.dbRemoveListener {
@@ -114,6 +169,12 @@ func (s *DatabaseService) DeleteDatabase(
}
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database deleted: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return s.dbRepository.Delete(id)
}
@@ -126,17 +187,44 @@ func (s *DatabaseService) GetDatabase(
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("you have not access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot access database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access this database")
}
database.HideSensitiveData()
return database, nil
}
func (s *DatabaseService) GetDatabasesByUser(
func (s *DatabaseService) GetDatabasesByWorkspace(
user *users_models.User,
workspaceID uuid.UUID,
) ([]*Database, error) {
return s.dbRepository.FindByUserID(user.ID)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access this workspace")
}
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return nil, err
}
for _, database := range databases {
database.HideSensitiveData()
}
return databases, nil
}
func (s *DatabaseService) IsNotifierUsing(
@@ -160,11 +248,19 @@ func (s *DatabaseService) TestDatabaseConnection(
return err
}
if database.UserID != user.ID {
return errors.New("you have not access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot test connection for database without workspace")
}
err = database.TestConnection(s.logger)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to test connection for this database")
}
err = database.TestConnection(s.logger, s.fieldEncryptor)
if err != nil {
lastSaveError := err.Error()
database.LastBackupErrorMessage = &lastSaveError
@@ -184,7 +280,31 @@ func (s *DatabaseService) TestDatabaseConnection(
func (s *DatabaseService) TestDatabaseConnectionDirect(
database *Database,
) error {
return database.TestConnection(s.logger)
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return err
}
if database.WorkspaceID != nil && existingDatabase.WorkspaceID != nil &&
*existingDatabase.WorkspaceID != *database.WorkspaceID {
return errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return err
}
usingDatabase = existingDatabase
} else {
usingDatabase = database
}
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
}
func (s *DatabaseService) GetDatabaseByID(
@@ -237,13 +357,21 @@ func (s *DatabaseService) CopyDatabase(
return nil, err
}
if existingDatabase.UserID != user.ID {
return nil, errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return nil, errors.New("cannot copy database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to copy this database")
}
newDatabase := &Database{
ID: uuid.Nil,
UserID: user.ID,
WorkspaceID: existingDatabase.WorkspaceID,
Name: existingDatabase.Name + " (Copy)",
Type: existingDatabase.Type,
Notifiers: existingDatabase.Notifiers,
@@ -286,6 +414,12 @@ func (s *DatabaseService) CopyDatabase(
listener.OnDatabaseCopied(databaseID, copiedDatabase.ID)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database copied: %s to %s", existingDatabase.Name, copiedDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return copiedDatabase, nil
}
@@ -306,3 +440,19 @@ func (s *DatabaseService) SetHealthStatus(
return nil
}
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return err
}
if len(databases) > 0 {
return fmt.Errorf(
"workspace contains %d databases that must be deleted",
len(databases),
)
}
return nil
}

View File

@@ -10,14 +10,14 @@ import (
)
func CreateTestDatabase(
userID uuid.UUID,
workspaceID uuid.UUID,
storage *storages.Storage,
notifier *notifiers.Notifier,
) *Database {
database := &Database{
UserID: userID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
WorkspaceID: &workspaceID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,

View File

@@ -0,0 +1,9 @@
package secrets
var secretKeyService = &SecretKeyService{
nil,
}
func GetSecretKeyService() *SecretKeyService {
return secretKeyService
}

View File

@@ -0,0 +1 @@
package secrets

View File

@@ -0,0 +1,73 @@
package secrets
import (
"errors"
"fmt"
"os"
"postgresus-backend/internal/config"
user_models "postgresus-backend/internal/features/users/models"
"postgresus-backend/internal/storage"
"github.com/google/uuid"
"gorm.io/gorm"
)
type SecretKeyService struct {
cachedKey *string
}
func (s *SecretKeyService) MigrateKeyFromDbToFileIfExist() error {
var secretKey user_models.SecretKey
err := storage.GetDb().First(&secretKey).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil
}
return fmt.Errorf("failed to check for secret key in database: %w", err)
}
if secretKey.Secret == "" {
return nil
}
secretKeyPath := config.GetEnv().SecretKeyPath
if err := os.WriteFile(secretKeyPath, []byte(secretKey.Secret), 0600); err != nil {
return fmt.Errorf("failed to write secret key to file: %w", err)
}
if err := storage.GetDb().Exec("DELETE FROM secret_keys").Error; err != nil {
return fmt.Errorf("failed to delete secret key from database: %w", err)
}
return nil
}
func (s *SecretKeyService) GetSecretKey() (string, error) {
if s.cachedKey != nil {
return *s.cachedKey, nil
}
secretKeyPath := config.GetEnv().SecretKeyPath
data, err := os.ReadFile(secretKeyPath)
if err != nil {
if os.IsNotExist(err) {
newKey := s.generateNewSecretKey()
if err := os.WriteFile(secretKeyPath, []byte(newKey), 0600); err != nil {
return "", fmt.Errorf("failed to write new secret key: %w", err)
}
s.cachedKey = &newKey
return newKey, nil
}
return "", fmt.Errorf("failed to read secret key file: %w", err)
}
key := string(data)
s.cachedKey = &key
return key, nil
}
func (s *SecretKeyService) generateNewSecretKey() string {
return uuid.New().String() + uuid.New().String()
}

View File

@@ -10,23 +10,34 @@ import (
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func Test_CheckPgHealthUseCase(t *testing.T) {
user := users.GetTestUser()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
// Create workspace directly via service
workspace, err := workspaces_testing.CreateTestWorkspaceDirect("Test Workspace", user.UserID)
if err != nil {
t.Fatalf("Failed to create workspace: %v", err)
}
defer storages.RemoveTestStorage(storage.ID)
defer notifiers.RemoveTestNotifier(notifier)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspaceDirect(workspace.ID)
}()
t.Run("Test_DbAttemptFailed_DbMarkedAsUnavailable", func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender
@@ -94,7 +105,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbNotMarkerdAsDownAfterFirstAttempt",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender
@@ -160,7 +171,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbMarkerdAsDownAfterThirdFailedAttempt",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Make sure DB is available
@@ -237,7 +248,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
)
t.Run("Test_UnavailableDbAttemptSucceed_DbMarkedAsAvailable", func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Make sure DB is unavailable
@@ -311,7 +322,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbHealthcheckExecutedFast_HealthcheckNotExecutedFasterThanInterval",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender

View File

@@ -2,7 +2,7 @@ package healthcheck_attempt
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"time"
"github.com/gin-gonic/gin"
@@ -11,7 +11,6 @@ import (
type HealthcheckAttemptController struct {
healthcheckAttemptService *HealthcheckAttemptService
userService *users.UserService
}
func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
@@ -31,9 +30,9 @@ func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /healthcheck-attempts/{databaseId} [get]
func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -43,7 +42,7 @@ func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
return
}
afterDate := time.Now().UTC()
afterDate := time.Now().UTC().Add(-7 * 24 * time.Hour)
if afterDateStr := ctx.Query("afterDate"); afterDateStr != "" {
parsedDate, err := time.Parse(time.RFC3339, afterDateStr)
if err != nil {

View File

@@ -0,0 +1,261 @@
package healthcheck_attempt
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetHealthcheckAttemptController(),
)
return router
}
func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get healthcheck attempts",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get healthcheck attempts",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
pastTime := time.Now().UTC().Add(-1 * time.Hour)
createTestHealthcheckAttemptWithTime(
database.ID,
databases.HealthStatusAvailable,
pastTime,
)
createTestHealthcheckAttemptWithTime(
database.ID,
databases.HealthStatusUnavailable,
pastTime.Add(-30*time.Minute),
)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
if tt.expectSuccess {
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
assert.GreaterOrEqual(t, len(response), 2)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "forbidden")
}
})
}
}
func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
oldTime := time.Now().UTC().Add(-2 * time.Hour)
recentTime := time.Now().UTC().Add(-30 * time.Minute)
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusAvailable, oldTime)
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusUnavailable, recentTime)
createTestHealthcheckAttempt(database.ID, databases.HealthStatusAvailable)
afterDate := time.Now().UTC().Add(-1 * time.Hour)
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/healthcheck-attempts/%s?afterDate=%s",
database.ID.String(),
afterDate.Format(time.RFC3339),
),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, 2, len(response))
for _, attempt := range response {
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
}
}
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, 0, len(response))
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestHealthcheckAttempt(databaseID uuid.UUID, status databases.HealthStatus) {
createTestHealthcheckAttemptWithTime(databaseID, status, time.Now().UTC())
}
func createTestHealthcheckAttemptWithTime(
databaseID uuid.UUID,
status databases.HealthStatus,
createdAt time.Time,
) {
repo := GetHealthcheckAttemptRepository()
attempt := &HealthcheckAttempt{
ID: uuid.New(),
DatabaseID: databaseID,
Status: status,
CreatedAt: createdAt,
}
if err := repo.Create(attempt); err != nil {
panic("Failed to create test healthcheck attempt: " + err.Error())
}
}

View File

@@ -4,7 +4,7 @@ import (
"postgresus-backend/internal/features/databases"
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -12,6 +12,7 @@ var healthcheckAttemptRepository = &HealthcheckAttemptRepository{}
var healthcheckAttemptService = &HealthcheckAttemptService{
healthcheckAttemptRepository,
databases.GetDatabaseService(),
workspaces_services.GetWorkspaceService(),
}
var checkPgHealthUseCase = &CheckPgHealthUseCase{
@@ -27,7 +28,10 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,
users.GetUserService(),
}
func GetHealthcheckAttemptRepository() *HealthcheckAttemptRepository {
return healthcheckAttemptRepository
}
func GetHealthcheckAttemptService() *HealthcheckAttemptService {

View File

@@ -53,7 +53,7 @@ func (r *HealthcheckAttemptRepository) DeleteOlderThan(
Delete(&HealthcheckAttempt{}).Error
}
func (r *HealthcheckAttemptRepository) Insert(
func (r *HealthcheckAttemptRepository) Create(
attempt *HealthcheckAttempt,
) error {
if attempt.ID == uuid.Nil {
@@ -67,6 +67,12 @@ func (r *HealthcheckAttemptRepository) Insert(
return storage.GetDb().Create(attempt).Error
}
func (r *HealthcheckAttemptRepository) Insert(
attempt *HealthcheckAttempt,
) error {
return r.Create(attempt)
}
func (r *HealthcheckAttemptRepository) FindByDatabaseIDWithLimit(
databaseID uuid.UUID,
limit int,

View File

@@ -4,6 +4,7 @@ import (
"errors"
"postgresus-backend/internal/features/databases"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"time"
"github.com/google/uuid"
@@ -12,6 +13,7 @@ import (
type HealthcheckAttemptService struct {
healthcheckAttemptRepository *HealthcheckAttemptRepository
databaseService *databases.DatabaseService
workspaceService *workspaces_services.WorkspaceService
}
func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
@@ -24,7 +26,15 @@ func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
return nil, err
}
if database.UserID != user.ID {
if database.WorkspaceID == nil {
return nil, errors.New("cannot access healthcheck attempts for databases without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("forbidden")
}

View File

@@ -2,7 +2,7 @@ package healthcheck_config
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type HealthcheckConfigController struct {
healthcheckConfigService *HealthcheckConfigService
userService *users.UserService
}
func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
@@ -31,9 +30,9 @@ func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /healthcheck-config [post]
func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -65,9 +64,9 @@ func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
// @Failure 401
// @Router /healthcheck-config/{databaseId} [get]
func (c *HealthcheckConfigController) GetHealthcheckConfig(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}

View File

@@ -0,0 +1,328 @@
package healthcheck_config
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetHealthcheckConfigController(),
)
return router
}
func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer cannot save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can save healthcheck config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
request := HealthcheckConfigDTO{
DatabaseID: database.ID,
IsHealthcheckEnabled: true,
IsSentNotificationWhenUnavailable: true,
IntervalMinutes: 5,
AttemptsBeforeConcideredAsDown: 3,
StoreAttemptsDays: 7,
}
if tt.expectSuccess {
var response map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
&response,
)
assert.Contains(t, response["message"], "successfully")
} else {
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := HealthcheckConfigDTO{
DatabaseID: database.ID,
IsHealthcheckEnabled: true,
IsSentNotificationWhenUnavailable: true,
IntervalMinutes: 5,
AttemptsBeforeConcideredAsDown: 3,
StoreAttemptsDays: 7,
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get healthcheck config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get healthcheck config",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
if tt.expectSuccess {
var response HealthcheckConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsHealthcheckEnabled)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response HealthcheckConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsHealthcheckEnabled)
assert.True(t, response.IsSentNotificationWhenUnavailable)
assert.Equal(t, 1, response.IntervalMinutes)
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
assert.Equal(t, 7, response.StoreAttemptsDays)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}

View File

@@ -1,8 +1,9 @@
package healthcheck_config
import (
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -10,11 +11,12 @@ var healthcheckConfigRepository = &HealthcheckConfigRepository{}
var healthcheckConfigService = &HealthcheckConfigService{
databases.GetDatabaseService(),
healthcheckConfigRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
logger.GetLogger(),
}
var healthcheckConfigController = &HealthcheckConfigController{
healthcheckConfigService,
users.GetUserService(),
}
func GetHealthcheckConfigService() *HealthcheckConfigService {

View File

@@ -2,9 +2,12 @@ package healthcheck_config
import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/google/uuid"
)
@@ -12,6 +15,8 @@ import (
type HealthcheckConfigService struct {
databaseService *databases.DatabaseService
healthcheckConfigRepository *HealthcheckConfigRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
logger *slog.Logger
}
@@ -33,8 +38,16 @@ func (s *HealthcheckConfigService) Save(
return err
}
if database.UserID != user.ID {
return errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot modify healthcheck config for databases without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, &user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to modify healthcheck config")
}
healthcheckConfig := configDTO.ToDTO()
@@ -60,6 +73,12 @@ func (s *HealthcheckConfigService) Save(
}
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Healthcheck config updated for database '%s'", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}
@@ -72,8 +91,16 @@ func (s *HealthcheckConfigService) GetByDatabaseID(
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot access healthcheck config for databases without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to view healthcheck config")
}
config, err := s.healthcheckConfigRepository.GetByDatabaseID(database.ID)

View File

@@ -2,15 +2,16 @@ package notifiers
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type NotifierController struct {
notifierService *NotifierService
userService *users.UserService
notifierService *NotifierService
workspaceService *workspaces_services.WorkspaceService
}
func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,35 +30,40 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param notifier body Notifier true "Notifier data"
// @Param request body Notifier true "Notifier data with workspaceId"
// @Success 200 {object} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers [post]
func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var notifier Notifier
if err := ctx.ShouldBindJSON(&notifier); err != nil {
var request Notifier
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := notifier.Validate(); err != nil {
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.notifierService.SaveNotifier(user, &notifier); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, notifier)
ctx.JSON(http.StatusOK, request)
}
// GetNotifier
@@ -70,11 +76,12 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
// @Success 200 {object} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id} [get]
func (c *NotifierController) GetNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -86,6 +93,10 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
notifier, err := c.notifierService.GetNotifier(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -95,22 +106,41 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
// GetNotifiers
// @Summary Get all notifiers
// @Description Get all notifiers for the current user
// @Description Get all notifiers for a workspace
// @Tags notifiers
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers [get]
func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
notifiers, err := c.notifierService.GetNotifiers(user)
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -128,11 +158,12 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id} [delete]
func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -142,13 +173,11 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
return
}
notifier, err := c.notifierService.GetNotifier(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.notifierService.DeleteNotifier(user, notifier.ID); err != nil {
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -166,11 +195,12 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id}/test [post]
func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -181,6 +211,10 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
}
if err := c.notifierService.SendTestNotification(user, id); err != nil {
if err.Error() == "insufficient permissions to test notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -195,28 +229,44 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param notifier body Notifier true "Notifier data"
// @Param request body Notifier true "Notifier data with workspaceId"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/direct-test [post]
func (c *NotifierController) SendTestNotificationDirect(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var notifier Notifier
if err := ctx.ShouldBindJSON(&notifier); err != nil {
var request Notifier
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// For direct test, associate with the current user
notifier.UserID = user.ID
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.notifierService.SendTestNotificationToNotifier(&notifier); err != nil {
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !canView {
ctx.JSON(
http.StatusForbidden,
gin.H{"error": "insufficient permissions to test notifier in this workspace"},
)
return
}
if err := c.notifierService.SendTestNotificationToNotifier(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,7 +1,9 @@
package notifiers
import (
"postgresus-backend/internal/features/users"
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"postgresus-backend/internal/util/logger"
)
@@ -9,10 +11,13 @@ var notifierRepository = &NotifierRepository{}
var notifierService = &NotifierService{
notifierRepository,
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
encryption.GetFieldEncryptor(),
}
var notifierController = &NotifierController{
notifierService,
users.GetUserService(),
workspaces_services.GetWorkspaceService(),
}
func GetNotifierController() *NotifierController {
@@ -22,3 +27,10 @@ func GetNotifierController() *NotifierController {
func GetNotifierService() *NotifierService {
return notifierService
}
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
}

View File

@@ -1,9 +1,21 @@
package notifiers
import "log/slog"
import (
"log/slog"
"postgresus-backend/internal/util/encryption"
)
type NotificationSender interface {
Send(logger *slog.Logger, heading string, message string) error
Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error
Validate() error
Validate(encryptor encryption.FieldEncryptor) error
HideSensitiveData()
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
}

View File

@@ -9,13 +9,14 @@ import (
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
type Notifier struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null;type:uuid;index"`
WorkspaceID uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;not null;type:uuid;index"`
Name string `json:"name" gorm:"column:name;not null;type:varchar(255)"`
NotifierType NotifierType `json:"notifierType" gorm:"column:notifier_type;not null;type:varchar(50)"`
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
return "notifiers"
}
func (n *Notifier) Validate() error {
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.Name == "" {
return errors.New("name is required")
}
return n.getSpecificNotifier().Validate()
return n.getSpecificNotifier().Validate(encryptor)
}
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
err := n.getSpecificNotifier().Send(logger, heading, message)
func (n *Notifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
if err != nil {
lastSendError := err.Error()
@@ -54,6 +60,46 @@ func (n *Notifier) Send(logger *slog.Logger, heading string, message string) err
return err
}
func (n *Notifier) HideSensitiveData() {
n.getSpecificNotifier().HideSensitiveData()
}
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
}
func (n *Notifier) Update(incoming *Notifier) {
n.Name = incoming.Name
n.NotifierType = incoming.NotifierType
switch n.NotifierType {
case NotifierTypeTelegram:
if n.TelegramNotifier != nil && incoming.TelegramNotifier != nil {
n.TelegramNotifier.Update(incoming.TelegramNotifier)
}
case NotifierTypeEmail:
if n.EmailNotifier != nil && incoming.EmailNotifier != nil {
n.EmailNotifier.Update(incoming.EmailNotifier)
}
case NotifierTypeWebhook:
if n.WebhookNotifier != nil && incoming.WebhookNotifier != nil {
n.WebhookNotifier.Update(incoming.WebhookNotifier)
}
case NotifierTypeSlack:
if n.SlackNotifier != nil && incoming.SlackNotifier != nil {
n.SlackNotifier.Update(incoming.SlackNotifier)
}
case NotifierTypeDiscord:
if n.DiscordNotifier != nil && incoming.DiscordNotifier != nil {
n.DiscordNotifier.Update(incoming.DiscordNotifier)
}
case NotifierTypeTeams:
if n.TeamsNotifier != nil && incoming.TeamsNotifier != nil {
n.TeamsNotifier.Update(incoming.TeamsNotifier)
}
}
}
func (n *Notifier) getSpecificNotifier() NotificationSender {
switch n.NotifierType {
case NotifierTypeTelegram:

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
return "discord_notifiers"
}
func (d *DiscordNotifier) Validate() error {
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
return nil
}
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (d *DiscordNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal Discord payload: %w", err)
}
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}
@@ -71,3 +82,24 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return nil
}
func (d *DiscordNotifier) HideSensitiveData() {
d.ChannelWebhookURL = ""
}
func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
if incoming.ChannelWebhookURL != "" {
d.ChannelWebhookURL = incoming.ChannelWebhookURL
}
}
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if d.ChannelWebhookURL != "" {
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
d.ChannelWebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net"
"net/smtp"
"postgresus-backend/internal/util/encryption"
"time"
"github.com/google/uuid"
@@ -27,13 +28,14 @@ type EmailNotifier struct {
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
}
func (e *EmailNotifier) TableName() string {
return "email_notifiers"
}
func (e *EmailNotifier) Validate() error {
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if e.TargetEmail == "" {
return errors.New("target email is required")
}
@@ -54,11 +56,29 @@ func (e *EmailNotifier) Validate() error {
return nil
}
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (e *EmailNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
// Decrypt SMTP password if provided
var smtpPassword string
if e.SMTPPassword != "" {
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
}
smtpPassword = decrypted
}
// Compose email
from := e.SMTPUser
from := e.From
if from == "" {
from = "noreply@" + e.SMTPHost
from = e.SMTPUser
if from == "" {
from = "noreply@" + e.SMTPHost
}
}
to := []string{e.TargetEmail}
@@ -72,15 +92,16 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
)
body := message
fromHeader := fmt.Sprintf("From: %s\r\n", from)
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
// Combine all parts of the email
emailContent := []byte(fromHeader + subject + mime + body)
emailContent := []byte(fromHeader + toHeader + subject + mime + body)
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
timeout := DefaultTimeout
// Determine if authentication is required
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
// Handle different port scenarios
if e.SMTPPort == ImplicitTLSPort {
@@ -111,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Set up authentication only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -174,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
// Authenticate only if credentials are provided
if isAuthRequired {
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
if err := client.Auth(auth); err != nil {
return fmt.Errorf("SMTP authentication failed: %w", err)
}
@@ -208,3 +229,30 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
return client.Quit()
}
}
func (e *EmailNotifier) HideSensitiveData() {
e.SMTPPassword = ""
}
func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.TargetEmail = incoming.TargetEmail
e.SMTPHost = incoming.SMTPHost
e.SMTPPort = incoming.SMTPPort
e.SMTPUser = incoming.SMTPUser
e.From = incoming.From
if incoming.SMTPPassword != "" {
e.SMTPPassword = incoming.SMTPPassword
}
}
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if e.SMTPPassword != "" {
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
if err != nil {
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
}
e.SMTPPassword = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"io"
"log/slog"
"net/http"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
"time"
@@ -23,7 +24,7 @@ type SlackNotifier struct {
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
func (s *SlackNotifier) Validate() error {
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if s.BotToken == "" {
return errors.New("bot token is required")
}
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
return nil
}
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
func (s *SlackNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
full := fmt.Sprintf("*%s*", heading)
if message != "" {
@@ -80,7 +90,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
}
req.Header.Set("Content-Type", "application/json; charset=utf-8")
req.Header.Set("Authorization", "Bearer "+s.BotToken)
req.Header.Set("Authorization", "Bearer "+botToken)
resp, err := http.DefaultClient.Do(req)
if err != nil {
@@ -132,3 +142,26 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
return nil
}
}
func (s *SlackNotifier) HideSensitiveData() {
s.BotToken = ""
}
func (s *SlackNotifier) Update(incoming *SlackNotifier) {
s.TargetChatID = incoming.TargetChatID
if incoming.BotToken != "" {
s.BotToken = incoming.BotToken
}
}
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if s.BotToken != "" {
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
s.BotToken = encrypted
}
return nil
}

View File

@@ -8,6 +8,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
return "teams_notifiers"
}
func (n *TeamsNotifier) Validate() error {
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL == "" {
return errors.New("webhook_url is required")
}
u, err := url.Parse(n.WebhookURL)
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
u, err := url.Parse(webhookURL)
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
return errors.New("invalid webhook_url")
}
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
}
type cardAttachment struct {
ContentType string `json:"contentType"`
Content interface{} `json:"content"`
ContentType string `json:"contentType"`
Content any `json:"content"`
}
type payload struct {
@@ -43,11 +50,20 @@ type payload struct {
Attachments []cardAttachment `json:"attachments,omitempty"`
}
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
if err := n.Validate(); err != nil {
func (n *TeamsNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading, message string,
) error {
if err := n.Validate(encryptor); err != nil {
return err
}
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
card := map[string]any{
"type": "AdaptiveCard",
"version": "1.4",
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
}
body, _ := json.Marshal(p)
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
if err != nil {
return err
}
@@ -94,3 +110,24 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
return nil
}
func (n *TeamsNotifier) HideSensitiveData() {
n.WebhookURL = ""
}
func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
if incoming.WebhookURL != "" {
n.WebhookURL = incoming.WebhookURL
}
}
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if n.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
n.WebhookURL = encrypted
}
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"strconv"
"strings"
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
return "telegram_notifiers"
}
func (t *TelegramNotifier) Validate() error {
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.BotToken == "" {
return errors.New("bot token is required")
}
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
return nil
}
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *TelegramNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to decrypt bot token: %w", err)
}
fullMessage := heading
if message != "" {
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
}
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
data := url.Values{}
data.Set("chat_id", t.TargetChatID)
@@ -80,3 +91,27 @@ func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message str
return nil
}
func (t *TelegramNotifier) HideSensitiveData() {
t.BotToken = ""
}
func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
t.TargetChatID = incoming.TargetChatID
t.ThreadID = incoming.ThreadID
if incoming.BotToken != "" {
t.BotToken = incoming.BotToken
}
}
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.BotToken != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
if err != nil {
return fmt.Errorf("failed to encrypt bot token: %w", err)
}
t.BotToken = encrypted
}
return nil
}

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"net/http"
"net/url"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
return "webhook_notifiers"
}
func (t *WebhookNotifier) Validate() error {
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL == "" {
return errors.New("webhook URL is required")
}
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
return nil
}
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
func (t *WebhookNotifier) Send(
encryptor encryption.FieldEncryptor,
logger *slog.Logger,
heading string,
message string,
) error {
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
}
switch t.WebhookMethod {
case WebhookMethodGET:
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
t.WebhookURL,
webhookURL,
url.QueryEscape(heading),
url.QueryEscape(message),
)
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("failed to marshal webhook payload: %w", err)
}
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
if err != nil {
return fmt.Errorf("failed to send POST webhook: %w", err)
}
@@ -102,3 +113,22 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
}
func (t *WebhookNotifier) HideSensitiveData() {
}
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
t.WebhookURL = incoming.WebhookURL
t.WebhookMethod = incoming.WebhookMethod
}
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
if t.WebhookURL != "" {
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
if err != nil {
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
}
t.WebhookURL = encrypted
}
return nil
}

View File

@@ -143,7 +143,7 @@ func (r *NotifierRepository) FindByID(id uuid.UUID) (*Notifier, error) {
return &notifier, nil
}
func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error) {
func (r *NotifierRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Notifier, error) {
var notifiers []*Notifier
if err := storage.
@@ -154,7 +154,7 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
Preload("SlackNotifier").
Preload("DiscordNotifier").
Preload("TeamsNotifier").
Where("user_id = ?", userID).
Where("workspace_id = ?", workspaceID).
Order("name ASC").
Find(&notifiers).Error; err != nil {
return nil, err
@@ -165,7 +165,6 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
func (r *NotifierRepository) Delete(notifier *Notifier) error {
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
switch notifier.NotifierType {
case NotifierTypeTelegram:
if notifier.TelegramNotifier != nil {

View File

@@ -2,8 +2,13 @@ package notifiers
import (
"errors"
"fmt"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/encryption"
"github.com/google/uuid"
)
@@ -11,30 +16,77 @@ import (
type NotifierService struct {
notifierRepository *NotifierRepository
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
fieldEncryptor encryption.FieldEncryptor
}
func (s *NotifierService) SaveNotifier(
user *users_models.User,
workspaceID uuid.UUID,
notifier *Notifier,
) error {
if notifier.ID != uuid.Nil {
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
}
isUpdate := notifier.ID != uuid.Nil
if isUpdate {
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
if err != nil {
return err
}
if existingNotifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
if existingNotifier.WorkspaceID != workspaceID {
return errors.New("notifier does not belong to this workspace")
}
notifier.UserID = existingNotifier.UserID
} else {
notifier.UserID = user.ID
}
existingNotifier.Update(notifier)
_, err := s.notifierRepository.Save(notifier)
if err != nil {
return err
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
_, err = s.notifierRepository.Save(existingNotifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
} else {
notifier.WorkspaceID = workspaceID
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := notifier.Validate(s.fieldEncryptor); err != nil {
return err
}
_, err = s.notifierRepository.Save(notifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier created: %s", notifier.Name),
&user.ID,
&workspaceID,
)
}
return nil
@@ -49,11 +101,26 @@ func (s *NotifierService) DeleteNotifier(
return err
}
if notifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
canManage, err := s.workspaceService.CanUserManageDBs(notifier.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
}
return s.notifierRepository.Delete(notifier)
err = s.notifierRepository.Delete(notifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier deleted: %s", notifier.Name),
&user.ID,
&notifier.WorkspaceID,
)
return nil
}
func (s *NotifierService) GetNotifier(
@@ -65,17 +132,40 @@ func (s *NotifierService) GetNotifier(
return nil, err
}
if notifier.UserID != user.ID {
return nil, errors.New("you have not access to this notifier")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifier in this workspace")
}
notifier.HideSensitiveData()
return notifier, nil
}
func (s *NotifierService) GetNotifiers(
user *users_models.User,
workspaceID uuid.UUID,
) ([]*Notifier, error) {
return s.notifierRepository.FindByUserID(user.ID)
canView, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifiers in this workspace")
}
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return nil, err
}
for _, notifier := range notifiers {
notifier.HideSensitiveData()
}
return notifiers, nil
}
func (s *NotifierService) SendTestNotification(
@@ -87,11 +177,15 @@ func (s *NotifierService) SendTestNotification(
return err
}
if notifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
if err != nil {
return err
}
if !canView {
return errors.New("insufficient permissions to test notifier in this workspace")
}
err = notifier.Send(s.logger, "Test message", "This is a test message")
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
if err != nil {
return err
}
@@ -107,7 +201,38 @@ func (s *NotifierService) SendTestNotification(
func (s *NotifierService) SendTestNotificationToNotifier(
notifier *Notifier,
) error {
return notifier.Send(s.logger, "Test message", "This is a test message")
var usingNotifier *Notifier
if notifier.ID != uuid.Nil {
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
if err != nil {
return err
}
if existingNotifier.WorkspaceID != notifier.WorkspaceID {
return errors.New("notifier does not belong to this workspace")
}
existingNotifier.Update(notifier)
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = existingNotifier
} else {
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
return err
}
usingNotifier = notifier
}
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
}
func (s *NotifierService) SendNotification(
@@ -126,7 +251,7 @@ func (s *NotifierService) SendNotification(
return
}
err = notifiedFromDb.Send(s.logger, title, message)
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
if err != nil {
errMsg := err.Error()
notifiedFromDb.LastSendError = &errMsg
@@ -143,3 +268,18 @@ func (s *NotifierService) SendNotification(
s.logger.Error("Failed to save notifier", "error", err)
}
}
func (s *NotifierService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return fmt.Errorf("failed to get notifiers for workspace deletion: %w", err)
}
for _, notifier := range notifiers {
if err := s.notifierRepository.Delete(notifier); err != nil {
return fmt.Errorf("failed to delete notifier %s: %w", notifier.ID, err)
}
}
return nil
}

View File

@@ -6,9 +6,9 @@ import (
"github.com/google/uuid"
)
func CreateTestNotifier(userID uuid.UUID) *Notifier {
func CreateTestNotifier(workspaceID uuid.UUID) *Notifier {
notifier := &Notifier{
UserID: userID,
WorkspaceID: workspaceID,
Name: "test " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{

View File

@@ -2,7 +2,7 @@ package restores
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type RestoreController struct {
restoreService *RestoreService
userService *users.UserService
}
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,24 +28,18 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /restores/{backupId} [get]
func (c *RestoreController) GetRestores(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
backupID, err := uuid.Parse(ctx.Param("backupId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
restores, err := c.restoreService.GetRestores(user, backupID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -66,6 +59,12 @@ func (c *RestoreController) GetRestores(ctx *gin.Context) {
// @Failure 401
// @Router /restores/{backupId}/restore [post]
func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
backupID, err := uuid.Parse(ctx.Param("backupId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
@@ -78,18 +77,6 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.restoreService.RestoreBackupWithAuth(user, backupID, requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return

View File

@@ -0,0 +1,348 @@
package restores
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
local_storage "postgresus-backend/internal/features/storages/models/local"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
util_encryption "postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
return router
}
func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
var restores []*models.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&restores,
)
assert.NotNil(t, restores)
assert.Equal(t, 0, len(restores))
assert.NotNil(t, database)
}
func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+nonMember.Token,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
var restores []*models.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&restores,
)
assert.NotNil(t, restores)
}
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
}
func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Database restored from backup") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for restore not found")
}
func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
backup := createTestBackup(database, owner)
return database, backup
}
func createTestDatabase(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,
Type: storages.StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
}
repo := &storages.StorageRepository{}
storage, err := repo.Save(storage)
if err != nil {
panic(err)
}
return storage
}
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
fieldEncryptor := util_encryption.GetFieldEncryptor()
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &backups.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}

View File

@@ -1,12 +1,13 @@
package restores
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/restores/usecases"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -19,10 +20,11 @@ var restoreService = &RestoreService{
usecases.GetRestoreBackupUsecase(),
databases.GetDatabaseService(),
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
}
var restoreController = &RestoreController{
restoreService,
users.GetUserService(),
}
var restoreBackgroundService = &RestoreBackgroundService{

View File

@@ -62,8 +62,6 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.
if err := storage.
GetDb().
Preload("Backup.Storage").
Preload("Backup.Database").
Preload("Backup").
Preload("Postgresql").
Where("status = ?", status).

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
@@ -12,6 +13,7 @@ import (
"postgresus-backend/internal/features/restores/usecases"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/tools"
"time"
@@ -26,6 +28,8 @@ type RestoreService struct {
restoreBackupUsecase *usecases.RestoreBackupUsecase
databaseService *databases.DatabaseService
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
@@ -58,8 +62,24 @@ func (s *RestoreService) GetRestores(
return nil, err
}
if backup.Database.UserID != user.ID {
return nil, errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot get restores for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access restores for this backup")
}
return s.restoreRepository.FindByBackupID(backupID)
@@ -75,8 +95,24 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if backup.Database.UserID != user.ID {
return errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot restore backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to restore this backup")
}
backupDatabase, err := s.databaseService.GetDatabase(user, backup.DatabaseID)
@@ -105,6 +141,16 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
}()
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database restored from backup %s for database: %s",
backupID.String(),
database.Name,
),
&user.ID,
database.WorkspaceID,
)
return nil
}
@@ -116,7 +162,12 @@ func (s *RestoreService) RestoreBackup(
return errors.New("backup is not completed")
}
if backup.Database.Type == databases.DatabaseTypePostgres {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.Type == databases.DatabaseTypePostgres {
if requestDTO.PostgresqlDatabase == nil {
return errors.New("postgresql database is required")
}
@@ -157,7 +208,7 @@ func (s *RestoreService) RestoreBackup(
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
backup.Database.ID,
database.ID,
)
if err != nil {
return err
@@ -168,6 +219,7 @@ func (s *RestoreService) RestoreBackup(
err = s.restoreBackupUsecase.Execute(
backupConfig,
restore,
database,
backup,
storage,
)

View File

@@ -1,11 +1,13 @@
package usecases_postgresql
import (
"postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/util/logger"
)
var restorePostgresqlBackupUsecase = &RestorePostgresqlBackupUsecase{
logger.GetLogger(),
secrets.GetSecretKeyService(),
}
func GetRestorePostgresqlBackupUsecase() *RestorePostgresqlBackupUsecase {

View File

@@ -2,6 +2,7 @@ package usecases_postgresql
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
@@ -15,11 +16,14 @@ import (
"postgresus-backend/internal/config"
"postgresus-backend/internal/features/backups/backups"
"postgresus-backend/internal/features/backups/backups/encryption"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
util_encryption "postgresus-backend/internal/util/encryption"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/tools"
@@ -27,16 +31,18 @@ import (
)
type RestorePostgresqlBackupUsecase struct {
logger *slog.Logger
logger *slog.Logger
secretKeyService *encryption_secrets.SecretKeyService
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
database *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
storage *storages.Storage,
) error {
if backup.Database.Type != databases.DatabaseTypePostgres {
if database.Type != databases.DatabaseTypePostgres {
return errors.New("database type not supported")
}
@@ -76,6 +82,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
}
return uc.restoreFromStorage(
database,
tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
@@ -92,6 +99,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
database *databases.Database,
pgBin string,
args []string,
password string,
@@ -164,7 +172,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// Add the temporary backup file as the last argument to pg_restore
args = append(args, tempBackupFile)
return uc.executePgRestore(ctx, pgBin, args, pgpassFile, pgConfig, backup)
return uc.executePgRestore(ctx, database, pgBin, args, pgpassFile, pgConfig)
}
// downloadBackupToTempFile downloads backup data from storage to a temporary file
@@ -199,18 +207,67 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
backup.ID,
"tempFile",
tempBackupFile,
"encrypted",
backup.Encryption == backups_config.BackupEncryptionEncrypted,
)
backupReader, err := storage.GetFile(backup.ID)
fieldEncryptor := util_encryption.GetFieldEncryptor()
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
}
defer func() {
if err := backupReader.Close(); err != nil {
if err := rawReader.Close(); err != nil {
uc.logger.Error("Failed to close backup reader", "error", err)
}
}()
// Create a reader that handles decryption if needed
var backupReader io.Reader = rawReader
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
// Validate encryption metadata
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
cleanupFunc()
return "", nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
}
// Get master key
masterKey, err := uc.secretKeyService.GetSecretKey()
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to get master key for decryption: %w", err)
}
// Decode salt and IV from base64
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption salt: %w", err)
}
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to decode encryption IV: %w", err)
}
// Create decryption reader
decryptReader, err := encryption.NewDecryptionReader(
rawReader,
masterKey,
backup.ID,
salt,
iv,
)
if err != nil {
cleanupFunc()
return "", nil, fmt.Errorf("failed to create decryption reader: %w", err)
}
backupReader = decryptReader
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
}
// Create temporary backup file
tempFile, err := os.Create(tempBackupFile)
if err != nil {
@@ -240,11 +297,11 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
// executePgRestore executes the pg_restore command with proper environment setup
func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
ctx context.Context,
database *databases.Database,
pgBin string,
args []string,
pgpassFile string,
pgConfig *pgtypes.PostgresqlDatabase,
backup *backups.Backup,
) error {
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL restore command", "command", cmd.String())
@@ -293,7 +350,7 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
return fmt.Errorf("restore cancelled due to shutdown")
}
return uc.handlePgRestoreError(waitErr, stderrOutput, pgBin, args, backup, pgConfig)
return uc.handlePgRestoreError(database, waitErr, stderrOutput, pgBin, args, pgConfig)
}
return nil
@@ -341,11 +398,11 @@ func (uc *RestorePostgresqlBackupUsecase) setupPgRestoreEnvironment(
// handlePgRestoreError processes and formats pg_restore errors
func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
database *databases.Database,
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
backup *backups.Backup,
pgConfig *pgtypes.PostgresqlDatabase,
) error {
// Enhanced error handling for PostgreSQL connection and restore issues
@@ -416,8 +473,8 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
)
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
backupDbName := "unknown"
if backup.Database != nil && backup.Database.Postgresql != nil && backup.Database.Postgresql.Database != nil {
backupDbName = *backup.Database.Postgresql.Database
if database.Postgresql != nil && database.Postgresql.Database != nil {
backupDbName = *database.Postgresql.Database
}
targetDbName := "unknown"

View File

@@ -17,11 +17,13 @@ type RestoreBackupUsecase struct {
func (uc *RestoreBackupUsecase) Execute(
backupConfig *backups_config.BackupConfig,
restore models.Restore,
database *databases.Database,
backup *backups.Backup,
storage *storages.Storage,
) error {
if restore.Backup.Database.Type == databases.DatabaseTypePostgres {
if database.Type == databases.DatabaseTypePostgres {
return uc.restorePostgresqlBackupUsecase.Execute(
database,
backupConfig,
restore,
backup,

View File

@@ -2,15 +2,16 @@ package storages
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type StorageController struct {
storageService *StorageService
userService *users.UserService
storageService *StorageService
workspaceService *workspaces_services.WorkspaceService
}
func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,35 +30,40 @@ func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param storage body Storage true "Storage data"
// @Param request body Storage true "Storage data with workspaceId"
// @Success 200 {object} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages [post]
func (c *StorageController) SaveStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var storage Storage
if err := ctx.ShouldBindJSON(&storage); err != nil {
var request Storage
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := storage.Validate(); err != nil {
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.storageService.SaveStorage(user, &storage); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, storage)
ctx.JSON(http.StatusOK, request)
}
// GetStorage
@@ -70,11 +76,12 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
// @Success 200 {object} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id} [get]
func (c *StorageController) GetStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -86,6 +93,10 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
storage, err := c.storageService.GetStorage(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -95,22 +106,41 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
// GetStorages
// @Summary Get all storages
// @Description Get all storages for the current user
// @Description Get all storages for a workspace
// @Tags storages
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages [get]
func (c *StorageController) GetStorages(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
storages, err := c.storageService.GetStorages(user)
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
storages, err := c.storageService.GetStorages(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view storages in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -128,11 +158,12 @@ func (c *StorageController) GetStorages(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id} [delete]
func (c *StorageController) DeleteStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -143,6 +174,10 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
}
if err := c.storageService.DeleteStorage(user, id); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -160,11 +195,12 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id}/test [post]
func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -175,6 +211,10 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
}
if err := c.storageService.TestStorageConnection(user, id); err != nil {
if err.Error() == "insufficient permissions to test storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -189,33 +229,44 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param storage body Storage true "Storage data"
// @Param request body Storage true "Storage data with workspaceId"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/direct-test [post]
func (c *StorageController) TestStorageConnectionDirect(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Storage
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
var storage Storage
if err := ctx.ShouldBindJSON(&storage); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// For direct test, associate with the current user
storage.UserID = user.ID
if err := storage.Validate(); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
if !canView {
ctx.JSON(
http.StatusForbidden,
gin.H{"error": "insufficient permissions to test storage in this workspace"},
)
return
}
if err := c.storageService.TestStorageConnectionDirect(&storage); err != nil {
if err := c.storageService.TestStorageConnectionDirect(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

View File

@@ -1,25 +1,46 @@
package storages
import (
"fmt"
"net/http"
local_storage "postgresus-backend/internal/features/storages/models/local"
"postgresus-backend/internal/features/users"
test_utils "postgresus-backend/internal/util/testing"
"strings"
"testing"
audit_logs "postgresus-backend/internal/features/audit_logs"
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
local_storage "postgresus-backend/internal/features/storages/models/local"
nas_storage "postgresus-backend/internal/features/storages/models/nas"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/encryption"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
verifyStorageData(t, storage, &savedStorage)
@@ -30,8 +51,8 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/storages/"+savedStorage.ID.String(),
user.Token,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
@@ -41,181 +62,788 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
// Verify storage is returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&storages,
)
assert.Contains(t, storages, savedStorage)
RemoveTestStorage(savedStorage.ID)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UpdateExistingStorage_UpdatedStorageReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
// Save initial storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Modify storage name
updatedName := "Updated Storage " + uuid.New().String()
savedStorage.Name = updatedName
// Update storage
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, savedStorage, http.StatusOK, &updatedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
// Verify updated data
assert.Equal(t, updatedName, updatedStorage.Name)
assert.Equal(t, savedStorage.ID, updatedStorage.ID)
// Verify through GET
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/storages/"+updatedStorage.ID.String(),
user.Token,
http.StatusOK,
&retrievedStorage,
)
verifyStorageData(t, &updatedStorage, &retrievedStorage)
// Verify storage is returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
)
assert.Contains(t, storages, updatedStorage)
RemoveTestStorage(updatedStorage.ID)
deleteStorage(t, router, updatedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteStorage_StorageNotReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
// Save initial storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Delete storage
test_utils.MakeDeleteRequest(
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusOK,
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
// Try to get deleted storage, should return error
response := test_utils.MakeGetRequest(
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusBadRequest,
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusBadRequest,
)
assert.Contains(t, string(response.Body), "error")
// Verify storage is not returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
)
assert.NotContains(t, storages, savedStorage)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_TestDirectStorageConnection_ConnectionEstablished(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
response := test_utils.MakePostRequest(
t, router, "/api/v1/storages/direct-test", user.Token, storage, http.StatusOK,
t, router, "/api/v1/storages/direct-test", "Bearer "+owner.Token, *storage, http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_TestExistingStorageConnection_ConnectionEstablished(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Test connection to existing storage
response := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages/"+savedStorage.ID.String()+"/test",
user.Token,
fmt.Sprintf("/api/v1/storages/%s/test", savedStorage.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
RemoveTestStorage(savedStorage.ID)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CallAllMethodsWithoutAuth_UnauthorizedErrorReturned(t *testing.T) {
func Test_ViewerCanViewStorages_ButCannotModify(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(uuid.New())
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
viewer,
users_enums.WorkspaceRoleViewer,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Test endpoints without auth
endpoints := []struct {
method string
url string
body interface{}
}{
{"GET", "/api/v1/storages", nil},
{"GET", "/api/v1/storages/" + uuid.New().String(), nil},
{"POST", "/api/v1/storages", storage},
{"DELETE", "/api/v1/storages/" + uuid.New().String(), nil},
{"POST", "/api/v1/storages/" + uuid.New().String() + "/test", nil},
{"POST", "/api/v1/storages/direct-test", storage},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
for _, endpoint := range endpoints {
testUnauthorizedEndpoint(t, router, endpoint.method, endpoint.url, endpoint.body)
}
// Viewer can GET storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+viewer.Token,
http.StatusOK,
&storages,
)
assert.Len(t, storages, 1)
// Viewer cannot CREATE storage
newStorage := createNewStorage(workspace.ID)
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, *newStorage, http.StatusForbidden,
)
// Viewer cannot UPDATE storage
savedStorage.Name = "Updated by viewer"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, savedStorage, http.StatusForbidden,
)
// Viewer cannot DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+viewer.Token,
http.StatusForbidden,
)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func testUnauthorizedEndpoint(
t *testing.T,
router *gin.Engine,
method, url string,
body interface{},
) {
test_utils.MakeRequest(t, router, test_utils.RequestOptions{
Method: method,
URL: url,
Body: body,
ExpectedStatus: http.StatusUnauthorized,
})
func Test_MemberCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Member can CREATE storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
*storage,
http.StatusOK,
&savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
// Member can UPDATE storage
savedStorage.Name = "Updated by member"
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
assert.Equal(t, "Updated by member", updatedStorage.Name)
// Member can DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_AdminCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
admin,
users_enums.WorkspaceRoleAdmin,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Admin can CREATE, UPDATE, DELETE
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*storage,
http.StatusOK,
&savedStorage,
)
savedStorage.Name = "Updated by admin"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+admin.Token, savedStorage, http.StatusOK,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Outsider cannot GET storages
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
// Outsider cannot CREATE storage
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+outsider.Token, *storage, http.StatusForbidden,
)
// Outsider cannot UPDATE storage
test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+outsider.Token,
savedStorage,
http.StatusForbidden,
)
// Outsider cannot DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *testing.T) {
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
storage1 := createNewStorage(workspace1.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner1.Token,
*storage1,
http.StatusOK,
&savedStorage,
)
// Try to access workspace1's storage with owner2 from workspace2
response := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner2.Token,
http.StatusForbidden,
)
assert.Contains(t, string(response.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, workspace1.ID, owner1.Token)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
"S3AccessKey should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
"S3SecretKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
assert.NoError(t, err)
assert.Equal(t, "original-access-key", accessKey)
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
assert.NoError(t, err)
assert.Equal(t, "original-secret-key", secretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
{
name: "NAS Storage",
storageType: StorageTypeNAS,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Test NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas.example.com",
Port: 445,
Share: "backups",
Username: "testuser",
Password: "original-password",
UseSSL: false,
Domain: "WORKGROUP",
Path: "/test",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeNAS,
Name: "Updated NAS Storage",
NASStorage: &nas_storage.NASStorage{
Host: "nas2.example.com",
Port: 445,
Share: "backups2",
Username: "testuser2",
Password: "",
UseSSL: true,
Domain: "WORKGROUP2",
Path: "/test2",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
"Password should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
assert.NoError(t, err)
assert.Equal(t, "original-password", password)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.NASStorage.Password)
},
},
{
name: "Azure Blob Storage (Connection String)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "original-connection-string",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob Storage",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
ConnectionString: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
"ConnectionString should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
connectionString, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.ConnectionString,
)
assert.NoError(t, err)
assert.Equal(t, "original-connection-string", connectionString)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Azure Blob Storage (Account Key)",
storageType: StorageTypeAzureBlob,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Test Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "testaccount",
AccountKey: "original-account-key",
ContainerName: "test-container",
Endpoint: "",
Prefix: "backups/",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeAzureBlob,
Name: "Updated Azure Blob with Account Key",
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
AccountName: "updatedaccount",
AccountKey: "",
ContainerName: "updated-container",
Endpoint: "https://custom.blob.core.windows.net",
Prefix: "backups2/",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
"AccountKey should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
accountKey, err := encryptor.Decrypt(
storage.ID,
storage.AzureBlobStorage.AccountKey,
)
assert.NoError(t, err)
assert.Equal(t, "original-account-key", accountKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
},
},
{
name: "Google Drive Storage",
storageType: StorageTypeGoogleDrive,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Test Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "original-client-id",
ClientSecret: "original-client-secret",
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeGoogleDrive,
Name: "Updated Google Drive Storage",
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
ClientID: "updated-client-id",
ClientSecret: "",
TokenJSON: "",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
"ClientSecret should be encrypted with 'enc:' prefix")
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
"TokenJSON should be encrypted with 'enc:' prefix")
encryptor := encryption.GetFieldEncryptor()
clientSecret, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.ClientSecret,
)
assert.NoError(t, err)
assert.Equal(t, "original-client-secret", clientSecret)
tokenJSON, err := encryptor.Decrypt(
storage.ID,
storage.GoogleDriveStorage.TokenJSON,
)
assert.NoError(t, err)
assert.Equal(
t,
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
tokenJSON,
)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Verify sensitive data is encrypted in repository after creation
repository := &StorageRepository{}
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
tc.verifySensitiveData(t, storageFromDBAfterCreate)
// Phase 3: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
controller := GetStorageController()
v1 := router.Group("/api/v1")
controller.RegisterRoutes(v1)
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetStorageController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewStorage(userID uuid.UUID) *Storage {
func createNewStorage(workspaceID uuid.UUID) *Storage {
return &Storage{
UserID: userID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
@@ -225,5 +853,20 @@ func createNewStorage(userID uuid.UUID) *Storage {
func verifyStorageData(t *testing.T, expected *Storage, actual *Storage) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Type, actual.Type)
assert.Equal(t, expected.UserID, actual.UserID)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteStorage(
t *testing.T,
router *gin.Engine,
storageID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", storageID.String()),
"Bearer "+token,
http.StatusOK,
)
}

Some files were not shown because too many files have changed in this diff Show More