mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
31 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95c833b619 | ||
|
|
878fad5747 | ||
|
|
6ff3096695 | ||
|
|
b4b514c2d5 | ||
|
|
da0fec6624 | ||
|
|
408675023a | ||
|
|
0bc93389cc | ||
|
|
c8e6aea6e1 | ||
|
|
981ad21471 | ||
|
|
177a9c782c | ||
|
|
069d6bc8fe | ||
|
|
242d5543d4 | ||
|
|
02c735bc5a | ||
|
|
793b575146 | ||
|
|
a6e84b45f2 | ||
|
|
a941fbd093 | ||
|
|
4492ba41f5 | ||
|
|
3a5ac4b479 | ||
|
|
77aaabeaa1 | ||
|
|
01911dbf72 | ||
|
|
1a16f27a5d | ||
|
|
778db71625 | ||
|
|
45fc9a7fff | ||
|
|
7f5e786261 | ||
|
|
9b066bcb8a | ||
|
|
9ea795b48f | ||
|
|
a809dc8a9c | ||
|
|
bd053b51a3 | ||
|
|
431e9861f4 | ||
|
|
de1fd4c4da | ||
|
|
df55fd17d5 |
12
.github/workflows/ci-release.yml
vendored
12
.github/workflows/ci-release.yml
vendored
@@ -127,6 +127,7 @@ jobs:
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
TEST_GOOGLE_DRIVE_TOKEN_JSON=${{ secrets.TEST_GOOGLE_DRIVE_TOKEN_JSON }}
|
||||
# testing DBs
|
||||
TEST_POSTGRES_12_PORT=5000
|
||||
TEST_POSTGRES_13_PORT=5001
|
||||
TEST_POSTGRES_14_PORT=5002
|
||||
TEST_POSTGRES_15_PORT=5003
|
||||
@@ -136,8 +137,13 @@ jobs:
|
||||
# testing S3
|
||||
TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing Azure Blob
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
|
||||
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -151,6 +157,7 @@ jobs:
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
@@ -160,6 +167,9 @@ jobs:
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
|
||||
- name: Create data and temp directories
|
||||
run: |
|
||||
# Create directories that are used for backups and restore
|
||||
@@ -182,7 +192,7 @@ jobs:
|
||||
- name: Run Go tests
|
||||
run: |
|
||||
cd backend
|
||||
go test ./internal/...
|
||||
go test -p=1 -count=1 -failfast ./internal/...
|
||||
|
||||
- name: Stop test containers
|
||||
if: always()
|
||||
|
||||
@@ -77,7 +77,7 @@ ENV APP_VERSION=$APP_VERSION
|
||||
# Set production mode for Docker containers
|
||||
ENV ENV_MODE=production
|
||||
|
||||
# Install PostgreSQL server and client tools (versions 13-17)
|
||||
# Install PostgreSQL server and client tools (versions 12-18)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
wget ca-certificates gnupg lsb-release sudo gosu && \
|
||||
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
|
||||
@@ -85,7 +85,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
> /etc/apt/sources.list.d/pgdg.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 postgresql-18 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
|
||||
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
|
||||
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -187,7 +187,7 @@
|
||||
same "license" line as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 LogBull
|
||||
Copyright 2025 Postgresus
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
36
README.md
36
README.md
@@ -9,7 +9,7 @@
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
|
||||
@@ -40,13 +40,13 @@
|
||||
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
|
||||
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
|
||||
|
||||
### 🗄️ **Multiple Storage Destinations**
|
||||
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
|
||||
|
||||
- **Local storage**: Keep backups on your VPS/server
|
||||
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
|
||||
- **Secure**: All data stays under your control
|
||||
|
||||
### 📱 **Smart Notifications**
|
||||
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
|
||||
|
||||
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
|
||||
- **Real-time updates**: Success and failure notifications
|
||||
@@ -54,17 +54,31 @@
|
||||
|
||||
### 🐘 **PostgreSQL Support**
|
||||
|
||||
- **Multiple versions**: PostgreSQL 13, 14, 15, 16, 17 and 18
|
||||
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
|
||||
- **SSL support**: Secure connections available
|
||||
- **Easy restoration**: One-click restore from any backup
|
||||
|
||||
### 🔒 **Backup Encryption** <a href="https://postgresus.com/encryption">(docs)</a>
|
||||
|
||||
- **AES-256-GCM encryption**: Enterprise-grade protection for backup files
|
||||
- **Zero-trust storage**: Encrypted backups are useless so you can keep in shared storages like S3, Azure Blob Storage, etc.
|
||||
- **Optionality**: Encrypted backups are optional and can be enabled or disabled if you wish
|
||||
- **Download unencrypted**: You can still download unencrypted backups via the 'Download' button to use them in `pg_restore` or other tools.
|
||||
|
||||
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
|
||||
|
||||
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
|
||||
- **Access management**: Control who can view or manage specific databases with role-based permissions
|
||||
- **Audit logs**: Track all system activities and changes made by users
|
||||
- **User roles**: Assign viewer, member, admin or owner roles within workspaces
|
||||
|
||||
### 🐳 **Self-Hosted & Secure**
|
||||
|
||||
- **Docker-based**: Easy deployment and management
|
||||
- **Privacy-first**: All your data stays on your infrastructure
|
||||
- **Open source**: Apache 2.0 licensed, inspect every line of code
|
||||
|
||||
### 📦 Installation
|
||||
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
|
||||
|
||||
You have three ways to install Postgresus:
|
||||
|
||||
@@ -118,8 +132,6 @@ This single command will:
|
||||
Create a `docker-compose.yml` file with the following configuration:
|
||||
|
||||
```yaml
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
postgresus:
|
||||
container_name: postgresus
|
||||
@@ -149,14 +161,16 @@ docker compose up -d
|
||||
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
|
||||
|
||||
### 🔑 Resetting Admin Password
|
||||
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
|
||||
|
||||
If you need to reset the admin password, you can use the built-in password reset command:
|
||||
If you need to reset the password, you can use the built-in password reset command:
|
||||
|
||||
```bash
|
||||
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123"
|
||||
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123" --email="admin"
|
||||
```
|
||||
|
||||
Replace `admin` with the actual email address of the user whose password you want to reset.
|
||||
|
||||
---
|
||||
|
||||
## 📝 License
|
||||
@@ -167,4 +181,4 @@ This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENS
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Read [contributing guide](contribute/readme.md) for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
Contributions are welcome! Read <a href="https://postgresus.com/contributing">contributing guide</a> for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
1607
assets/dashboard.svg
1607
assets/dashboard.svg
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 791 KiB After Width: | Height: | Size: 913 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 34 KiB |
@@ -1,14 +1,16 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and *gin.Context for all routes.
|
||||
Example:
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
@@ -17,24 +19,26 @@ func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http"
|
||||
|
||||
user_models "logbull/internal/features/users/models"
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
@@ -52,29 +56,30 @@ func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
@@ -94,34 +99,35 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "logbull/internal/features/users/models"
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -117,9 +119,11 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
@@ -128,12 +132,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "logbull/internal/features/users/enums"
|
||||
users_middleware "logbull/internal/features/users/middleware"
|
||||
users_services "logbull/internal/features/users/services"
|
||||
users_testing "logbull/internal/features/users/testing"
|
||||
"logbull/internal/storage"
|
||||
test_utils "logbull/internal/util/testing"
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
users_middleware "postgresus/internal/features/users/middleware"
|
||||
users_services "postgresus/internal/features/users/services"
|
||||
users_testing "postgresus/internal/features/users/testing"
|
||||
"postgresus/internal/storage"
|
||||
test_utils "postgresus/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -256,14 +260,16 @@ func createRouter() *gin.Engine {
|
||||
return router
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "logbull/internal/features/users/services"
|
||||
"logbull/internal/util/logger"
|
||||
users_services "postgresus/internal/features/users/services"
|
||||
"postgresus/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
@@ -289,9 +295,11 @@ func SetupDependencies() {
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
@@ -309,9 +317,11 @@ type GetAuditLogsResponse struct {
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
@@ -332,13 +342,15 @@ func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"logbull/internal/storage"
|
||||
"postgresus/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -429,9 +441,11 @@ func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
return count, err
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
@@ -439,8 +453,8 @@ import (
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "logbull/internal/features/users/enums"
|
||||
user_models "logbull/internal/features/users/models"
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -560,17 +574,19 @@ func (s *AuditLogService) GetProjectAuditLogs(
|
||||
}, nil
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
``````
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "logbull/internal/features/users/enums"
|
||||
users_testing "logbull/internal/features/users/testing"
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
users_testing "postgresus/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -652,4 +668,4 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
``````
|
||||
```
|
||||
|
||||
@@ -17,6 +17,7 @@ TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=
|
||||
TEST_GOOGLE_DRIVE_TOKEN_JSON="{\"access_token\":\"ya29..."
|
||||
# testing DBs
|
||||
TEST_POSTGRES_12_PORT=5000
|
||||
TEST_POSTGRES_13_PORT=5001
|
||||
TEST_POSTGRES_14_PORT=5002
|
||||
TEST_POSTGRES_15_PORT=5003
|
||||
@@ -27,4 +28,9 @@ TEST_POSTGRES_18_PORT=5006
|
||||
TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=7006
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=
|
||||
TEST_TELEGRAM_CHAT_ID=
|
||||
# testing Azure Blob Storage
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -12,4 +12,5 @@ swagger/swagger.yaml
|
||||
postgresus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
temp/
|
||||
cmd.exe
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/downdetect"
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
@@ -24,7 +24,10 @@ import (
|
||||
"postgresus-backend/internal/features/restores"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
system_healthcheck "postgresus-backend/internal/features/system/healthcheck"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_controllers "postgresus-backend/internal/features/users/controllers"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
env_utils "postgresus-backend/internal/util/env"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
@@ -61,13 +64,14 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Handle password reset if flag is provided
|
||||
newPassword := flag.String("new-password", "", "Set a new password for the user")
|
||||
flag.Parse()
|
||||
if *newPassword != "" {
|
||||
resetPassword(*newPassword, log)
|
||||
err = users_services.GetUserService().CreateInitialAdmin()
|
||||
if err != nil {
|
||||
log.Error("Failed to create initial admin", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
handlePasswordReset(log)
|
||||
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -91,11 +95,33 @@ func main() {
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
}
|
||||
|
||||
func resetPassword(newPassword string, log *slog.Logger) {
|
||||
func handlePasswordReset(log *slog.Logger) {
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
newPassword := flag.String("new-password", "", "Set a new password for the user")
|
||||
email := flag.String("email", "", "Email of the user to reset password")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *newPassword == "" {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Found reset password command - reseting password...")
|
||||
|
||||
if *email == "" {
|
||||
log.Info("No email provided, please provide an email via --email=\"some@email.com\" flag")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
resetPassword(*email, *newPassword, log)
|
||||
}
|
||||
|
||||
func resetPassword(email string, newPassword string, log *slog.Logger) {
|
||||
log.Info("Resetting password...")
|
||||
|
||||
userService := users.GetUserService()
|
||||
err := userService.ChangePassword(newPassword)
|
||||
userService := users_services.GetUserService()
|
||||
err := userService.ChangeUserPasswordByEmail(email, newPassword)
|
||||
if err != nil {
|
||||
log.Error("Failed to reset password", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -146,37 +172,44 @@ func setUpRoutes(r *gin.Engine) {
|
||||
// Mount Swagger UI
|
||||
v1.GET("/docs/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
downdetectContoller := downdetect.GetDowndetectController()
|
||||
userController := users.GetUserController()
|
||||
notifierController := notifiers.GetNotifierController()
|
||||
storageController := storages.GetStorageController()
|
||||
databaseController := databases.GetDatabaseController()
|
||||
backupController := backups.GetBackupController()
|
||||
restoreController := restores.GetRestoreController()
|
||||
healthcheckController := system_healthcheck.GetHealthcheckController()
|
||||
healthcheckConfigController := healthcheck_config.GetHealthcheckConfigController()
|
||||
healthcheckAttemptController := healthcheck_attempt.GetHealthcheckAttemptController()
|
||||
diskController := disk.GetDiskController()
|
||||
backupConfigController := backups_config.GetBackupConfigController()
|
||||
|
||||
downdetectContoller.RegisterRoutes(v1)
|
||||
// Public routes (only user auth routes and healthcheck should be public)
|
||||
userController := users_controllers.GetUserController()
|
||||
userController.RegisterRoutes(v1)
|
||||
notifierController.RegisterRoutes(v1)
|
||||
storageController.RegisterRoutes(v1)
|
||||
databaseController.RegisterRoutes(v1)
|
||||
backupController.RegisterRoutes(v1)
|
||||
restoreController.RegisterRoutes(v1)
|
||||
healthcheckController.RegisterRoutes(v1)
|
||||
diskController.RegisterRoutes(v1)
|
||||
healthcheckConfigController.RegisterRoutes(v1)
|
||||
healthcheckAttemptController.RegisterRoutes(v1)
|
||||
backupConfigController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
authMiddleware := users_middleware.AuthMiddleware(userService)
|
||||
|
||||
// Protected routes
|
||||
protected := v1.Group("")
|
||||
protected.Use(authMiddleware)
|
||||
|
||||
userController.RegisterProtectedRoutes(protected)
|
||||
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected)
|
||||
workspaces_controllers.GetMembershipController().RegisterRoutes(protected)
|
||||
disk.GetDiskController().RegisterRoutes(protected)
|
||||
notifiers.GetNotifierController().RegisterRoutes(protected)
|
||||
storages.GetStorageController().RegisterRoutes(protected)
|
||||
databases.GetDatabaseController().RegisterRoutes(protected)
|
||||
backups.GetBackupController().RegisterRoutes(protected)
|
||||
restores.GetRestoreController().RegisterRoutes(protected)
|
||||
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
||||
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
||||
backups_config.GetBackupConfigController().RegisterRoutes(protected)
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
databases.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
|
||||
@@ -31,7 +31,26 @@ services:
|
||||
container_name: test-minio
|
||||
command: server /data --console-address ":9001"
|
||||
|
||||
# Test Azurite container
|
||||
test-azurite:
|
||||
image: mcr.microsoft.com/azure-storage/azurite
|
||||
ports:
|
||||
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
|
||||
container_name: test-azurite
|
||||
command: azurite-blob --blobHost 0.0.0.0
|
||||
|
||||
# Test PostgreSQL containers
|
||||
test-postgres-12:
|
||||
image: postgres:12
|
||||
ports:
|
||||
- "${TEST_POSTGRES_12_PORT}:5432"
|
||||
environment:
|
||||
- POSTGRES_DB=testdb
|
||||
- POSTGRES_USER=testuser
|
||||
- POSTGRES_PASSWORD=testpassword
|
||||
container_name: test-postgres-12
|
||||
shm_size: 1gb
|
||||
|
||||
test-postgres-13:
|
||||
image: postgres:13
|
||||
ports:
|
||||
|
||||
@@ -3,6 +3,8 @@ module postgresus-backend
|
||||
go 1.23.3
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -15,16 +17,18 @@ require (
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/minio/minio-go/v7 v7.0.92
|
||||
github.com/shirou/gopsutil/v4 v4.25.5
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/time v0.12.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
|
||||
require github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.16.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
@@ -99,12 +103,12 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
google.golang.org/api v0.239.0
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
|
||||
@@ -6,6 +6,18 @@ cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeO
|
||||
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 h1:KpMC6LFL7mqpExyMC9jVOYRiVhLmamjeZfRsUpB7l4s=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -80,6 +92,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -131,6 +145,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
@@ -159,6 +175,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY=
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
@@ -180,8 +198,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
|
||||
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
|
||||
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
|
||||
@@ -216,25 +234,25 @@ golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -247,8 +265,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -257,15 +275,15 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo=
|
||||
google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
|
||||
|
||||
@@ -33,6 +33,7 @@ type EnvVariables struct {
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
|
||||
TestPostgres12Port string `env:"TEST_POSTGRES_12_PORT"`
|
||||
TestPostgres13Port string `env:"TEST_POSTGRES_13_PORT"`
|
||||
TestPostgres14Port string `env:"TEST_POSTGRES_14_PORT"`
|
||||
TestPostgres15Port string `env:"TEST_POSTGRES_15_PORT"`
|
||||
@@ -43,7 +44,19 @@ type EnvVariables struct {
|
||||
TestMinioPort string `env:"TEST_MINIO_PORT"`
|
||||
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
|
||||
|
||||
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
|
||||
|
||||
TestNASPort string `env:"TEST_NAS_PORT"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -135,6 +148,10 @@ func loadEnvVariables() {
|
||||
env.TempFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "temp")
|
||||
|
||||
if env.IsTesting {
|
||||
if env.TestPostgres12Port == "" {
|
||||
log.Error("TEST_POSTGRES_12_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.TestPostgres13Port == "" {
|
||||
log.Error("TEST_POSTGRES_13_PORT is empty")
|
||||
os.Exit(1)
|
||||
@@ -169,10 +186,25 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestAzuriteBlobPort == "" {
|
||||
log.Error("TEST_AZURITE_BLOB_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestNASPort == "" {
|
||||
log.Error("TEST_NAS_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramBotToken == "" {
|
||||
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramChatID == "" {
|
||||
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DowndetectController struct {
|
||||
service *DowndetectService
|
||||
}
|
||||
|
||||
func (c *DowndetectController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/downdetect/is-available", c.IsAvailable)
|
||||
}
|
||||
|
||||
// @Summary Check API availability
|
||||
// @Description Checks if the API service is available
|
||||
// @Tags downdetect
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200
|
||||
// @Failure 500
|
||||
// @Router /downdetect/api [get]
|
||||
func (c *DowndetectController) IsAvailable(ctx *gin.Context) {
|
||||
err := c.service.IsDbAvailable()
|
||||
if err != nil {
|
||||
ctx.JSON(
|
||||
http.StatusInternalServerError,
|
||||
gin.H{"error": fmt.Sprintf("Database is not available: %v", err)},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "API and DB are available"})
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
var downdetectService = &DowndetectService{}
|
||||
var downdetectController = &DowndetectController{
|
||||
downdetectService,
|
||||
}
|
||||
|
||||
func GetDowndetectController() *DowndetectController {
|
||||
return downdetectController
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type DowndetectService struct {
|
||||
}
|
||||
|
||||
func (s *DowndetectService) IsDbAvailable() error {
|
||||
err := storage.GetDb().Exec("SELECT 1").Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
111
backend/internal/features/audit_logs/controller.go
Normal file
111
backend/internal/features/audit_logs/controller.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "postgresus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
154
backend/internal/features/audit_logs/controller_test.go
Normal file
154
backend/internal/features/audit_logs/controller_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
workspaceID := uuid.New()
|
||||
testID := uuid.New().String()
|
||||
|
||||
// Create test logs with unique identifiers
|
||||
userLogMessage := fmt.Sprintf("Test log with user %s", testID)
|
||||
workspaceLogMessage := fmt.Sprintf("Test log with workspace %s", testID)
|
||||
standaloneLogMessage := fmt.Sprintf("Test log standalone %s", testID)
|
||||
|
||||
createAuditLog(service, userLogMessage, &adminUser.UserID, nil)
|
||||
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
|
||||
createAuditLog(service, standaloneLogMessage, nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=100", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
// Verify our specific test logs are present
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, userLogMessage)
|
||||
assert.Contains(t, messages, workspaceLogMessage)
|
||||
assert.Contains(t, messages, standaloneLogMessage)
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
workspaceID := uuid.New()
|
||||
testID := uuid.New().String()
|
||||
|
||||
// Create test logs for different users with unique identifiers
|
||||
user1FirstMessage := fmt.Sprintf("Test log user1 first %s", testID)
|
||||
user1SecondMessage := fmt.Sprintf("Test log user1 second %s", testID)
|
||||
user2FirstMessage := fmt.Sprintf("Test log user2 first %s", testID)
|
||||
user2SecondMessage := fmt.Sprintf("Test log user2 second %s", testID)
|
||||
workspaceLogMessage := fmt.Sprintf("Test workspace log %s", testID)
|
||||
|
||||
createAuditLog(service, user1FirstMessage, &user1.UserID, nil)
|
||||
createAuditLog(service, user1SecondMessage, &user1.UserID, &workspaceID)
|
||||
createAuditLog(service, user2FirstMessage, &user2.UserID, nil)
|
||||
createAuditLog(service, user2SecondMessage, &user2.UserID, &workspaceID)
|
||||
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
// Verify user1's specific logs are present
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, user1FirstMessage)
|
||||
assert.Contains(t, messages, user1SecondMessage)
|
||||
|
||||
// Count only our test logs for user1
|
||||
testLogsCount := 0
|
||||
for _, message := range messages {
|
||||
if message == user1FirstMessage || message == user1SecondMessage {
|
||||
testLogsCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, testLogsCount)
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
|
||||
// Verify user2's specific logs are present
|
||||
ownMessages := extractMessages(ownLogsResponse.AuditLogs)
|
||||
assert.Contains(t, ownMessages, user2FirstMessage)
|
||||
assert.Contains(t, ownMessages, user2SecondMessage)
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithBeforeDateFilter_ReturnsFilteredLogs(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Set filter time to 30 minutes ago
|
||||
beforeTime := baseTime.Add(-30 * time.Minute)
|
||||
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/audit-logs/global?beforeDate=%s&limit=1000",
|
||||
beforeTime.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+adminUser.Token,
|
||||
http.StatusOK,
|
||||
&filteredResponse,
|
||||
)
|
||||
|
||||
// Verify ALL returned logs are older than the filter time
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime),
|
||||
fmt.Sprintf("Log created at %s should be before filter time %s",
|
||||
log.CreatedAt.Format(time.RFC3339), beforeTime.Format(time.RFC3339)))
|
||||
}
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
29
backend/internal/features/audit_logs/di.go
Normal file
29
backend/internal/features/audit_logs/di.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
31
backend/internal/features/audit_logs/dto.go
Normal file
31
backend/internal/features/audit_logs/dto.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLogDTO `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type AuditLogDTO struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
UserEmail *string `json:"userEmail" gorm:"column:user_email"`
|
||||
UserName *string `json:"userName" gorm:"column:user_name"`
|
||||
WorkspaceName *string `json:"workspaceName" gorm:"column:workspace_name"`
|
||||
}
|
||||
19
backend/internal/features/audit_logs/models.go
Normal file
19
backend/internal/features/audit_logs/models.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
139
backend/internal/features/audit_logs/repository.go
Normal file
139
backend/internal/features/audit_logs/repository.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByWorkspace(
|
||||
workspaceID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
137
backend/internal/features/audit_logs/service.go
Normal file
137
backend/internal/features/audit_logs/service.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
user_models "postgresus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
workspaceID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
WorkspaceID: workspaceID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(
|
||||
targetUserID,
|
||||
limit,
|
||||
offset,
|
||||
request.BeforeDate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetWorkspaceAuditLogs(
|
||||
workspaceID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByWorkspace(
|
||||
workspaceID,
|
||||
limit,
|
||||
offset,
|
||||
request.BeforeDate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
83
backend/internal/features/audit_logs/service_test.go
Normal file
83
backend/internal/features/audit_logs/service_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
workspace1ID, workspace2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for workspaces
|
||||
createAuditLog(service, "Test workspace1 log first", &user1.UserID, &workspace1ID)
|
||||
createAuditLog(service, "Test workspace1 log second", &user2.UserID, &workspace1ID)
|
||||
createAuditLog(service, "Test workspace2 log first", &user1.UserID, &workspace2ID)
|
||||
createAuditLog(service, "Test workspace2 log second", &user2.UserID, &workspace2ID)
|
||||
createAuditLog(service, "Test no workspace log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test workspace 1 logs
|
||||
workspace1Response, err := service.GetWorkspaceAuditLogs(workspace1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(workspace1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(workspace1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test workspace1 log first")
|
||||
assert.Contains(t, messages, "Test workspace1 log second")
|
||||
for _, log := range workspace1Response.AuditLogs {
|
||||
assert.Equal(t, &workspace1ID, log.WorkspaceID)
|
||||
}
|
||||
|
||||
// Test workspace 2 logs
|
||||
workspace2Response, err := service.GetWorkspaceAuditLogs(workspace2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(workspace2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(workspace2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test workspace2 log first")
|
||||
assert.Contains(t, messages2, "Test workspace2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
assert.NotNil(t, log.UserEmail, "User email should be present for logs with user_id")
|
||||
assert.NotNil(
|
||||
t,
|
||||
log.WorkspaceName,
|
||||
"Workspace name should be present for logs with workspace_id",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, workspaceID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, workspaceID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLogDTO) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"postgresus-backend/internal/config"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/period"
|
||||
"time"
|
||||
)
|
||||
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,10 +18,12 @@ import (
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -40,11 +44,8 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
@@ -67,16 +68,20 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -97,11 +102,8 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
@@ -124,16 +126,20 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -157,11 +163,8 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -185,16 +188,20 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -218,11 +225,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -246,16 +250,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -280,11 +288,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -309,6 +314,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if !exists {
|
||||
return errors.New("backup is not in progress or already completed")
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
type BackupController struct {
|
||||
backupService *BackupService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -20,51 +19,48 @@ func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backups", c.MakeBackup)
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
router.DELETE("/backups/:id", c.DeleteBackup)
|
||||
router.POST("/backups/:id/cancel", c.CancelBackup)
|
||||
}
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get all backups for the specified database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Success 200 {array} Backup
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [get]
|
||||
func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
databaseIDStr := ctx.Query("database_id")
|
||||
if databaseIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "database_id query parameter is required"})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(databaseIDStr)
|
||||
var request GetBackupsRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(request.DatabaseID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
backups, err := c.backupService.GetBackups(user, databaseID)
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, backups)
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// MakeBackup
|
||||
@@ -80,24 +76,18 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups [post]
|
||||
func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request MakeBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -117,24 +107,18 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups/{id} [delete]
|
||||
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.DeleteBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -143,6 +127,37 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// CancelBackup
|
||||
// @Summary Cancel an in-progress backup
|
||||
// @Description Cancel a backup that is currently in progress
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/cancel [post]
|
||||
func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.CancelBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup
|
||||
@@ -154,24 +169,18 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, err := c.backupService.GetBackupFile(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -179,19 +188,16 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
}
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
// Log the error but don't interrupt the response
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set headers for file download
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"backup_%s.dump\"", id.String()),
|
||||
)
|
||||
|
||||
// Stream the file content
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
|
||||
|
||||
709
backend/internal/features/backups/backups/controller_test.go
Normal file
709
backend/internal/features/backups/backups/controller_test.go
Normal file
@@ -0,0 +1,709 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
users_dto "postgresus-backend/internal/features/users/dto"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_models "postgresus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups?database_id=%s", database.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GetBackupsResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(response.Backups), 1)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(1))
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "backup started successfully")
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup manually initiated") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup creation not found")
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace member can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
} else {
|
||||
userService := users_services.GetUserService()
|
||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup deleted") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup deletion not found")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup file downloaded") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup download not found")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/cancel", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
// Verify audit log was created
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
userService := users_services.GetUserService()
|
||||
adminUser, err := userService.GetUserFromToken(admin.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetGlobalAuditLogs(
|
||||
adminUser,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabase(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
storage := &storages.Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: storages.StorageTypeLocal,
|
||||
Name: "Test Storage " + uuid.New().String(),
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
repo := &storages.StorageRepository{}
|
||||
storage, err := repo.Save(storage)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func enableBackupForDatabase(databaseID uuid.UUID) {
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a dummy backup file for testing download functionality
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
@@ -1,17 +1,23 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"time"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
@@ -19,9 +25,14 @@ var backupService = &BackupService{
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
@@ -35,7 +46,6 @@ var backupBackgroundService = &BackupBackgroundService{
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
|
||||
28
backend/internal/features/backups/backups/dto.go
Normal file
28
backend/internal/features/backups/backups/dto.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"io"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
)
|
||||
|
||||
// GetBackupsRequest is the query-string payload for listing the backups of a
// single database; database_id is mandatory, limit/offset page the results.
type GetBackupsRequest struct {
	DatabaseID string `form:"database_id" binding:"required"`
	Limit      int    `form:"limit"`
	Offset     int    `form:"offset"`
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// DecryptionReader streams chunked AES-GCM decryption of data produced by
// EncryptionWriter: it validates the file header once at construction, then
// decrypts length-prefixed chunks on demand into an internal buffer.
type DecryptionReader struct {
	baseReader io.Reader   // encrypted source stream
	cipher     cipher.AEAD // GCM instance keyed per backup
	buffer     []byte      // decrypted bytes not yet handed to the caller
	nonce      []byte      // base nonce; per-chunk nonces are derived from it
	chunkIndex uint64      // index of the next chunk, mixed into its nonce
	headerRead bool        // set once the header has been validated
	eof        bool        // set when the underlying stream is exhausted
}
|
||||
|
||||
func NewDecryptionReader(
|
||||
baseReader io.Reader,
|
||||
masterKey string,
|
||||
backupID uuid.UUID,
|
||||
salt []byte,
|
||||
nonce []byte,
|
||||
) (*DecryptionReader, error) {
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
|
||||
}
|
||||
if len(nonce) != NonceLen {
|
||||
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive backup key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(derivedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
reader := &DecryptionReader{
|
||||
baseReader,
|
||||
aesgcm,
|
||||
make([]byte, 0),
|
||||
nonce,
|
||||
0,
|
||||
false,
|
||||
false,
|
||||
}
|
||||
|
||||
if err := reader.readAndValidateHeader(salt, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
|
||||
for len(r.buffer) < len(p) && !r.eof {
|
||||
if err := r.readAndDecryptChunk(); err != nil {
|
||||
if err == io.EOF {
|
||||
r.eof = true
|
||||
break
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.buffer) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(p, r.buffer)
|
||||
r.buffer = r.buffer[n:]
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) readAndValidateHeader(expectedSalt, expectedNonce []byte) error {
|
||||
header := make([]byte, HeaderLen)
|
||||
|
||||
if _, err := io.ReadFull(r.baseReader, header); err != nil {
|
||||
return fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
magic := string(header[0:MagicBytesLen])
|
||||
if magic != MagicBytes {
|
||||
return fmt.Errorf("invalid magic bytes: expected %s, got %s", MagicBytes, magic)
|
||||
}
|
||||
|
||||
salt := header[MagicBytesLen : MagicBytesLen+SaltLen]
|
||||
nonce := header[MagicBytesLen+SaltLen : MagicBytesLen+SaltLen+NonceLen]
|
||||
|
||||
if string(salt) != string(expectedSalt) {
|
||||
return fmt.Errorf("salt mismatch in file header")
|
||||
}
|
||||
|
||||
if string(nonce) != string(expectedNonce) {
|
||||
return fmt.Errorf("nonce mismatch in file header")
|
||||
}
|
||||
|
||||
r.headerRead = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) readAndDecryptChunk() error {
|
||||
lengthBuf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r.baseReader, lengthBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunkLen := binary.BigEndian.Uint32(lengthBuf)
|
||||
if chunkLen == 0 || chunkLen > ChunkSize+16 {
|
||||
return fmt.Errorf("invalid chunk length: %d", chunkLen)
|
||||
}
|
||||
|
||||
encrypted := make([]byte, chunkLen)
|
||||
if _, err := io.ReadFull(r.baseReader, encrypted); err != nil {
|
||||
return fmt.Errorf("failed to read encrypted chunk: %w", err)
|
||||
}
|
||||
|
||||
chunkNonce := r.generateChunkNonce()
|
||||
|
||||
decrypted, err := r.cipher.Open(nil, chunkNonce, encrypted, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrypt chunk (authentication failed - file may be corrupted or tampered): %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
r.buffer = append(r.buffer, decrypted...)
|
||||
r.chunkIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) generateChunkNonce() []byte {
|
||||
chunkNonce := make([]byte, NonceLen)
|
||||
copy(chunkNonce, r.nonce)
|
||||
|
||||
binary.BigEndian.PutUint64(chunkNonce[4:], r.chunkIndex)
|
||||
|
||||
return chunkNonce
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// EncryptionWriter encrypts a stream with AES-GCM in fixed-size chunks, each
// chunk written as a 4-byte length prefix plus ciphertext and sealed with a
// nonce derived from its index. The file header (magic + salt + nonce) is
// written lazily on first use.
type EncryptionWriter struct {
	baseWriter    io.Writer   // destination for the header and encrypted chunks
	cipher        cipher.AEAD // GCM instance keyed per backup
	buffer        []byte      // plaintext not yet large enough for a full chunk
	nonce         []byte      // base nonce; per-chunk nonces are derived from it
	salt          []byte      // retained so the header can be written lazily
	chunkIndex    uint64      // index of the next chunk, mixed into its nonce
	headerWritten bool        // set once the header has been emitted
}
|
||||
|
||||
func NewEncryptionWriter(
|
||||
baseWriter io.Writer,
|
||||
masterKey string,
|
||||
backupID uuid.UUID,
|
||||
salt []byte,
|
||||
nonce []byte,
|
||||
) (*EncryptionWriter, error) {
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
|
||||
}
|
||||
if len(nonce) != NonceLen {
|
||||
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive backup key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(derivedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
writer := &EncryptionWriter{
|
||||
baseWriter: baseWriter,
|
||||
cipher: aesgcm,
|
||||
buffer: make([]byte, 0, ChunkSize),
|
||||
nonce: nonce,
|
||||
chunkIndex: 0,
|
||||
headerWritten: false,
|
||||
salt: salt, // Store salt for lazy header writing
|
||||
}
|
||||
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) Write(p []byte) (n int, err error) {
|
||||
// Write header on first write (lazy initialization)
|
||||
if !w.headerWritten {
|
||||
if err := w.writeHeader(w.salt, w.nonce); err != nil {
|
||||
return 0, fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
w.buffer = append(w.buffer, p...)
|
||||
|
||||
for len(w.buffer) >= ChunkSize {
|
||||
chunk := w.buffer[:ChunkSize]
|
||||
if err := w.encryptAndWriteChunk(chunk); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.buffer = w.buffer[ChunkSize:]
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) Close() error {
|
||||
// Write header if it hasn't been written yet (in case Close is called without any writes)
|
||||
if !w.headerWritten {
|
||||
if err := w.writeHeader(w.salt, w.nonce); err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(w.buffer) > 0 {
|
||||
if err := w.encryptAndWriteChunk(w.buffer); err != nil {
|
||||
return err
|
||||
}
|
||||
w.buffer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) writeHeader(salt, nonce []byte) error {
|
||||
header := make([]byte, HeaderLen)
|
||||
|
||||
copy(header[0:MagicBytesLen], []byte(MagicBytes))
|
||||
copy(header[MagicBytesLen:MagicBytesLen+SaltLen], salt)
|
||||
copy(header[MagicBytesLen+SaltLen:MagicBytesLen+SaltLen+NonceLen], nonce)
|
||||
|
||||
_, err := w.baseWriter.Write(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
w.headerWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) encryptAndWriteChunk(chunk []byte) error {
|
||||
chunkNonce := w.generateChunkNonce()
|
||||
|
||||
encrypted := w.cipher.Seal(nil, chunkNonce, chunk, nil)
|
||||
|
||||
lengthBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lengthBuf, uint32(len(encrypted)))
|
||||
|
||||
if _, err := w.baseWriter.Write(lengthBuf); err != nil {
|
||||
return fmt.Errorf("failed to write chunk length: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.baseWriter.Write(encrypted); err != nil {
|
||||
return fmt.Errorf("failed to write encrypted chunk: %w", err)
|
||||
}
|
||||
|
||||
w.chunkIndex++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) generateChunkNonce() []byte {
|
||||
chunkNonce := make([]byte, NonceLen)
|
||||
copy(chunkNonce, w.nonce)
|
||||
|
||||
binary.BigEndian.PutUint64(chunkNonce[4:], w.chunkIndex)
|
||||
|
||||
return chunkNonce
|
||||
}
|
||||
@@ -0,0 +1,387 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_EncryptDecryptRoundTrip_ReturnsOriginalData(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte(
|
||||
"This is a test backup data that should be encrypted and then decrypted successfully.",
|
||||
)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted := make([]byte, len(originalData))
|
||||
n, err = io.ReadFull(reader, decrypted)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_EncryptDecryptRoundTrip_LargeData_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := make([]byte, 100*1024)
|
||||
_, err = rand.Read(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_MultipleWrites_CombinesCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
part1 := []byte("First part of data. ")
|
||||
part2 := []byte("Second part of data. ")
|
||||
part3 := []byte("Third part of data.")
|
||||
expectedData := append(append(part1, part2...), part3...)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(part1)
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(part2)
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(part3)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedData, decrypted)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_InvalidHeader_ReturnsError(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidHeader := make([]byte, HeaderLen)
|
||||
copy(invalidHeader, []byte("INVALID!"))
|
||||
|
||||
invalidData := bytes.NewBuffer(invalidHeader)
|
||||
|
||||
_, err = NewDecryptionReader(invalidData, masterKey, backupID, salt, nonce)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid magic bytes")
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_TamperedData_ReturnsError(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("This data will be tampered with.")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
encryptedBytes := encrypted.Bytes()
|
||||
if len(encryptedBytes) > HeaderLen+10 {
|
||||
encryptedBytes[HeaderLen+10] ^= 0xFF
|
||||
}
|
||||
|
||||
tamperedBuffer := bytes.NewBuffer(encryptedBytes)
|
||||
|
||||
reader, err := NewDecryptionReader(tamperedBuffer, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
}
|
||||
|
||||
func Test_DeriveBackupKey_SameInputs_ReturnsSameKey(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
key1, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
key2, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, key1, key2)
|
||||
}
|
||||
|
||||
func Test_DeriveBackupKey_DifferentInputs_ReturnsDifferentKeys(t *testing.T) {
|
||||
masterKey1 := uuid.New().String() + uuid.New().String()
|
||||
masterKey2 := uuid.New().String() + uuid.New().String()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
salt1, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
salt2, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
key1, err := DeriveBackupKey(masterKey1, backupID1, salt1)
|
||||
require.NoError(t, err)
|
||||
|
||||
key2, err := DeriveBackupKey(masterKey2, backupID1, salt1)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key2)
|
||||
|
||||
key3, err := DeriveBackupKey(masterKey1, backupID2, salt1)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key3)
|
||||
|
||||
key4, err := DeriveBackupKey(masterKey1, backupID1, salt2)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key4)
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_PartialChunk_HandledCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
smallData := []byte("Small data less than chunk size")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(smallData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, smallData, decrypted)
|
||||
}
|
||||
|
||||
func Test_GenerateSalt_ReturnsCorrectLength(t *testing.T) {
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SaltLen, len(salt))
|
||||
}
|
||||
|
||||
func Test_GenerateSalt_GeneratesUniqueSalts(t *testing.T) {
|
||||
salt1, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
salt2, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, salt1, salt2)
|
||||
}
|
||||
|
||||
func Test_GenerateNonce_ReturnsCorrectLength(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
func Test_GenerateNonce_GeneratesUniqueNonces(t *testing.T) {
|
||||
nonce1, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, nonce1, nonce2)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_WrongMasterKey_ReturnsError(t *testing.T) {
|
||||
masterKey1 := uuid.New().String() + uuid.New().String()
|
||||
masterKey2 := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("Secret data")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey1, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey2, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_EmptyData_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(decrypted))
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_MultipleChunks_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
dataSize := ChunkSize*3 + 1000
|
||||
originalData := make([]byte, dataSize)
|
||||
_, err = rand.Read(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_SmallReads_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("This is test data that will be read in small chunks.")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decrypted []byte
|
||||
buf := make([]byte, 5)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
decrypted = append(decrypted, buf[:n]...)
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
MagicBytes = "PGRSUS01"
|
||||
MagicBytesLen = 8
|
||||
SaltLen = 32
|
||||
NonceLen = 12
|
||||
ReservedLen = 12
|
||||
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
|
||||
ChunkSize = 32 * 1024
|
||||
PBKDF2Iterations = 100000
|
||||
)
|
||||
|
||||
func DeriveBackupKey(masterKey string, backupID uuid.UUID, salt []byte) ([]byte, error) {
|
||||
if masterKey == "" {
|
||||
return nil, fmt.Errorf("master key cannot be empty")
|
||||
}
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes", SaltLen)
|
||||
}
|
||||
|
||||
keyMaterial := []byte(masterKey + backupID.String())
|
||||
|
||||
derivedKey := pbkdf2.Key(keyMaterial, salt, PBKDF2Iterations, 32, sha256.New)
|
||||
|
||||
return derivedKey, nil
|
||||
}
|
||||
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
salt := make([]byte, SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func GenerateNonce() ([]byte, error) {
|
||||
nonce := make([]byte, NonceLen)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
BackupStatusInProgress BackupStatus = "IN_PROGRESS"
|
||||
BackupStatusCompleted BackupStatus = "COMPLETED"
|
||||
BackupStatusFailed BackupStatus = "FAILED"
|
||||
BackupStatusCanceled BackupStatus = "CANCELED"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
@@ -19,6 +22,7 @@ type NotificationSender interface {
|
||||
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -26,7 +30,7 @@ type CreateBackupUsecase interface {
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error
|
||||
) (*usecases_postgresql.BackupMetadata, error)
|
||||
}
|
||||
|
||||
type BackupRemoveListener interface {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -11,11 +10,8 @@ import (
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
|
||||
Database *databases.Database `json:"database" gorm:"foreignKey:DatabaseID"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
|
||||
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
@@ -24,5 +20,9 @@ type Backup struct {
|
||||
|
||||
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
|
||||
|
||||
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
|
||||
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
|
||||
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
@@ -13,18 +13,20 @@ import (
|
||||
type BackupRepository struct{}
|
||||
|
||||
func (r *BackupRepository) Save(backup *Backup) error {
|
||||
if backup.DatabaseID == uuid.Nil || backup.StorageID == uuid.Nil {
|
||||
return errors.New("database ID and storage ID are required")
|
||||
}
|
||||
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := backup.ID == uuid.Nil
|
||||
if isNew {
|
||||
backup.ID = uuid.New()
|
||||
return db.Create(backup).
|
||||
Omit("Database", "Storage").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(backup).
|
||||
Omit("Database", "Storage").
|
||||
Error
|
||||
}
|
||||
|
||||
@@ -33,8 +35,6 @@ func (r *BackupRepository) FindByDatabaseID(databaseID uuid.UUID) ([]*Backup, er
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -56,8 +56,6 @@ func (r *BackupRepository) FindByDatabaseIDWithLimit(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
@@ -73,8 +71,6 @@ func (r *BackupRepository) FindByStorageID(storageID uuid.UUID) ([]*Backup, erro
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("storage_id = ?", storageID).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -89,8 +85,6 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error; err != nil {
|
||||
@@ -109,8 +103,6 @@ func (r *BackupRepository) FindByID(id uuid.UUID) (*Backup, error) {
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("id = ?", id).
|
||||
First(&backup).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -124,8 +116,6 @@ func (r *BackupRepository) FindByStatus(status BackupStatus) ([]*Backup, error)
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -143,8 +133,6 @@ func (r *BackupRepository) FindByStorageIdAndStatus(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("storage_id = ? AND status = ?", storageID, status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -162,8 +150,6 @@ func (r *BackupRepository) FindByDatabaseIdAndStatus(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ? AND status = ?", databaseID, status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -185,8 +171,6 @@ func (r *BackupRepository) FindBackupsBeforeDate(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ? AND created_at < ?", databaseID, date).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -195,3 +179,36 @@ func (r *BackupRepository) FindBackupsBeforeDate(
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindByDatabaseIDWithPagination(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Where("database_id = ?", databaseID).
|
||||
Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -1,16 +1,24 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,12 +31,18 @@ type BackupService struct {
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
@@ -62,34 +76,74 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot create backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canAccess {
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackups(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) ([]*Backup, error) {
|
||||
limit, offset int,
|
||||
) (*GetBackupsResponse, error) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot get backups for database without workspace")
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseID(databaseID)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access backups for this database")
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
total, err := s.backupRepository.CountByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetBackupsResponse{
|
||||
Backups: backups,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(
|
||||
@@ -101,14 +155,37 @@ func (s *BackupService) DeleteBackup(
|
||||
return err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this backup")
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot delete backup for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup deleted for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.deleteBackup(backup)
|
||||
}
|
||||
|
||||
@@ -154,10 +231,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
Database: database,
|
||||
|
||||
StorageID: storage.ID,
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
@@ -184,7 +258,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
}
|
||||
}
|
||||
|
||||
err = s.createBackupUseCase.Execute(
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
@@ -193,6 +272,34 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
@@ -225,6 +332,13 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
@@ -265,6 +379,11 @@ func (s *BackupService) SendBackupNotification(
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
@@ -276,9 +395,17 @@ func (s *BackupService) SendBackupNotification(
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf("❌ Backup failed for database \"%s\"", database.Name)
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf("✅ Backup completed for database \"%s\"", database.Name)
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
@@ -319,6 +446,53 @@ func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
func (s *BackupService) CancelBackup(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) error {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel backup for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
@@ -328,16 +502,37 @@ func (s *BackupService) GetBackupFile(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this backup")
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storage.GetFile(backup.ID)
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot download backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*database.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to download backup for this database")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.getBackupReader(backupID)
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
@@ -352,9 +547,12 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error
|
||||
s.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return s.backupRepository.DeleteByID(backup.ID)
|
||||
@@ -389,3 +587,91 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBackupReader returns a reader for the backup file
|
||||
// If encrypted, wraps with DecryptionReader
|
||||
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find backup: %w", err)
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get storage: %w", err)
|
||||
}
|
||||
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup file: %w", err)
|
||||
}
|
||||
|
||||
// If not encrypted, return raw reader
|
||||
if backup.Encryption == backups_config.BackupEncryptionNone {
|
||||
s.logger.Info("Returning non-encrypted backup", "backupId", backupID)
|
||||
return fileReader, nil
|
||||
}
|
||||
|
||||
// Decrypt on-the-fly for encrypted backups
|
||||
if backup.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported encryption type: %s", backup.Encryption)
|
||||
}
|
||||
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", err)
|
||||
}
|
||||
return nil, fmt.Errorf("backup marked as encrypted but missing encryption metadata")
|
||||
}
|
||||
|
||||
// Get master key
|
||||
masterKey, err := s.secretKeyRepo.GetSecretKey()
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and IV
|
||||
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode salt: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode IV: %w", err)
|
||||
}
|
||||
|
||||
// Wrap with decrypting reader
|
||||
decryptionReader, err := encryption.NewDecryptionReader(
|
||||
fileReader,
|
||||
masterKey,
|
||||
backup.ID,
|
||||
salt,
|
||||
iv,
|
||||
)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create decrypting reader: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -17,15 +25,27 @@ import (
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer storages.RemoveTestStorage(storage.ID)
|
||||
defer notifiers.RemoveTestNotifier(notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
@@ -36,9 +56,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
// Set up expectations
|
||||
@@ -79,9 +104,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
@@ -99,9 +129,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
// capture arguments
|
||||
@@ -137,6 +172,7 @@ type CreateFailedBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -144,15 +180,16 @@ func (uc *CreateFailedBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
return errors.New("backup failed")
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -160,7 +197,11 @@ func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
return nil
|
||||
return &usecases_postgresql.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
20
backend/internal/features/backups/backups/testing.go
Normal file
20
backend/internal/features/backups/backups/testing.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
return workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
GetBackupController(),
|
||||
)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
@@ -14,8 +15,9 @@ type CreateBackupUsecase struct {
|
||||
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database and returns the backup size in MB
|
||||
// Execute creates a backup of the database and returns the backup metadata
|
||||
func (uc *CreateBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -23,9 +25,10 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
return uc.CreatePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
database,
|
||||
@@ -34,5 +37,5 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
)
|
||||
}
|
||||
|
||||
return errors.New("database type not supported")
|
||||
return nil, errors.New("database type not supported")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package usecases_postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,21 +15,40 @@ import (
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
backupTimeout = 23 * time.Hour
|
||||
shutdownCheckInterval = 1 * time.Second
|
||||
copyBufferSize = 32 * 1024
|
||||
progressReportIntervalMB = 1.0
|
||||
pgConnectTimeout = 30
|
||||
compressionLevel = 5
|
||||
defaultBackupLimit = 1000
|
||||
exitCodeAccessViolation = -1073741819
|
||||
exitCodeGenericError = 1
|
||||
exitCodeConnectionError = 2
|
||||
)
|
||||
|
||||
type CreatePostgresqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
logger *slog.Logger
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database
|
||||
func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
@@ -36,7 +56,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*BackupMetadata, error) {
|
||||
uc.logger.Info(
|
||||
"Creating PostgreSQL backup via pg_dump custom format",
|
||||
"databaseId",
|
||||
@@ -46,41 +66,28 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
pg := db.Postgresql
|
||||
|
||||
if pg == nil {
|
||||
return fmt.Errorf("postgresql database configuration is required for pg_dump backups")
|
||||
return nil, fmt.Errorf("postgresql database configuration is required for pg_dump backups")
|
||||
}
|
||||
|
||||
if pg.Database == nil || *pg.Database == "" {
|
||||
return fmt.Errorf("database name is required for pg_dump backups")
|
||||
return nil, fmt.Errorf("database name is required for pg_dump backups")
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-Fc", // custom format with built-in compression
|
||||
"--no-password", // Use environment variable for password, prevent prompts
|
||||
"-h", pg.Host,
|
||||
"-p", strconv.Itoa(pg.Port),
|
||||
"-U", pg.Username,
|
||||
"-d", *pg.Database,
|
||||
"--verbose", // Add verbose output to help with debugging
|
||||
}
|
||||
args := uc.buildPgDumpArgs(pg)
|
||||
|
||||
// Use zstd compression level 5 for PostgreSQL 15+ (better compression and speed)
|
||||
// Fall back to gzip compression level 5 for older versions
|
||||
if pg.Version == tools.PostgresqlVersion13 || pg.Version == tools.PostgresqlVersion14 ||
|
||||
pg.Version == tools.PostgresqlVersion15 {
|
||||
args = append(args, "-Z", "5")
|
||||
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
|
||||
} else {
|
||||
args = append(args, "--compress=zstd:5")
|
||||
uc.logger.Info("Using zstd compression level 5", "version", pg.Version)
|
||||
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, pg.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
|
||||
}
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
tools.GetPostgresqlExecutable(
|
||||
@@ -90,7 +97,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
config.GetEnv().PostgresesInstallDir,
|
||||
),
|
||||
args,
|
||||
pg.Password,
|
||||
decryptedPassword,
|
||||
storage,
|
||||
db,
|
||||
backupProgressListener,
|
||||
@@ -99,6 +106,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// streamToStorage streams pg_dump output directly to storage
|
||||
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
pgBin string,
|
||||
@@ -107,36 +115,15 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
storage *storages.Storage,
|
||||
db *databases.Database,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) error {
|
||||
) (*BackupMetadata, error) {
|
||||
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
|
||||
|
||||
// if backup not fit into 23 hours, Postgresus
|
||||
// seems not to work for such database size
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 23*time.Hour)
|
||||
ctx, cancel := uc.createBackupContext(parentCtx)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create temporary .pgpass file as a more reliable alternative to PGPASSWORD
|
||||
pgpassFile, err := uc.createTempPgpassFile(db.Postgresql, password)
|
||||
pgpassFile, err := uc.setupPgpassFile(db.Postgresql, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if pgpassFile != "" {
|
||||
@@ -144,87 +131,21 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify .pgpass file was created successfully
|
||||
if pgpassFile == "" {
|
||||
return fmt.Errorf("temporary .pgpass file was not created")
|
||||
}
|
||||
|
||||
// Verify .pgpass file was created correctly
|
||||
if info, err := os.Stat(pgpassFile); err == nil {
|
||||
uc.logger.Info("Temporary .pgpass file created successfully",
|
||||
"pgpassFile", pgpassFile,
|
||||
"size", info.Size(),
|
||||
"mode", info.Mode(),
|
||||
)
|
||||
} else {
|
||||
return fmt.Errorf("failed to verify .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, pgBin, args...)
|
||||
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
|
||||
|
||||
// Start with system environment variables to preserve Windows PATH, SystemRoot, etc.
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// Use the .pgpass file for authentication
|
||||
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
|
||||
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
|
||||
|
||||
// Debug password setup (without exposing the actual password)
|
||||
uc.logger.Info("Setting up PostgreSQL environment",
|
||||
"passwordLength", len(password),
|
||||
"passwordEmpty", password == "",
|
||||
"pgBin", pgBin,
|
||||
"usingPgpassFile", true,
|
||||
"parallelJobs", backupConfig.CpuCount,
|
||||
)
|
||||
|
||||
// Add PostgreSQL-specific environment variables
|
||||
cmd.Env = append(cmd.Env, "PGCLIENTENCODING=UTF8")
|
||||
cmd.Env = append(cmd.Env, "PGCONNECT_TIMEOUT=30")
|
||||
|
||||
// Add encoding-related environment variables to handle character encoding issues
|
||||
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
|
||||
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
|
||||
|
||||
// Add PostgreSQL-specific encoding settings
|
||||
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
|
||||
|
||||
shouldRequireSSL := db.Postgresql.IsHttps
|
||||
|
||||
// Require SSL when explicitly configured
|
||||
if shouldRequireSSL {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
|
||||
uc.logger.Info("Using required SSL mode", "configuredHttps", db.Postgresql.IsHttps)
|
||||
} else {
|
||||
// SSL not explicitly required, but prefer it if available
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
|
||||
uc.logger.Info("Using preferred SSL mode", "configuredHttps", db.Postgresql.IsHttps)
|
||||
}
|
||||
|
||||
// Set other SSL parameters to avoid certificate issues
|
||||
cmd.Env = append(cmd.Env, "PGSSLCERT=") // No client certificate
|
||||
cmd.Env = append(cmd.Env, "PGSSLKEY=") // No client key
|
||||
cmd.Env = append(cmd.Env, "PGSSLROOTCERT=") // No root certificate verification
|
||||
cmd.Env = append(cmd.Env, "PGSSLCRL=") // No certificate revocation list
|
||||
|
||||
// Verify executable exists and is accessible
|
||||
if _, err := exec.LookPath(pgBin); err != nil {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL executable not found or not accessible: %s - %w",
|
||||
pgBin,
|
||||
err,
|
||||
)
|
||||
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, backupConfig.CpuCount, pgBin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pgStdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
return nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
pgStderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
return nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Capture stderr in a separate goroutine to ensure we don't miss any error output
|
||||
@@ -234,23 +155,31 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
stderrCh <- stderrOutput
|
||||
}()
|
||||
|
||||
// A pipe connecting pg_dump output → storage
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
// Create a counting writer to track bytes
|
||||
countingWriter := &CountingWriter{writer: storageWriter}
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
countingWriter := &CountingWriter{writer: finalWriter}
|
||||
|
||||
// The backup ID becomes the object key / filename in storage
|
||||
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErrCh <- storage.SaveFile(uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
// Start pg_dump
|
||||
if err = cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
|
||||
return nil, fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
|
||||
}
|
||||
|
||||
// Copy pg output directly to storage with shutdown checks
|
||||
@@ -272,23 +201,17 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
bytesWritten := <-bytesWrittenCh
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
|
||||
if err := pipeWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close counting writer", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
<-saveErrCh // Wait for storage to finish
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
// Check for shutdown or cancellation before finalizing
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
|
||||
return nil, uc.checkCancellationReason()
|
||||
default:
|
||||
}
|
||||
|
||||
// Close the pipe writer to signal end of data
|
||||
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
|
||||
if err := pipeWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close counting writer", "error", err)
|
||||
}
|
||||
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
|
||||
<-saveErrCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait until storage ends reading
|
||||
@@ -303,134 +226,34 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
|
||||
switch {
|
||||
case waitErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enhanced error handling for PostgreSQL connection and SSL issues
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
// Check for specific PostgreSQL error patterns
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
// Enhanced debugging for exit status 1 with empty stderr
|
||||
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
|
||||
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"env_vars", []string{
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT=30",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
},
|
||||
)
|
||||
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s failed with exit status 1 but provided no error details. "+
|
||||
"This often indicates: "+
|
||||
"1) Connection timeout or refused connection, "+
|
||||
"2) Authentication failure with incorrect credentials, "+
|
||||
"3) Database does not exist, "+
|
||||
"4) Network connectivity issues, "+
|
||||
"5) PostgreSQL server not running. "+
|
||||
"Command executed: %s %s",
|
||||
filepath.Base(pgBin),
|
||||
pgBin,
|
||||
strings.Join(args, " "),
|
||||
)
|
||||
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
|
||||
uc.logger.Error("PostgreSQL tool crashed with access violation",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"exitCode", fmt.Sprintf("0x%X", uint32(exitCode)),
|
||||
)
|
||||
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
stderrStr,
|
||||
)
|
||||
} else if exitCode == 1 || exitCode == 2 {
|
||||
// Check for common connection and authentication issues
|
||||
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection rejected by server configuration (pg_hba.conf). The server may not allow connections from your IP address or may require different authentication settings. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed - no password supplied. "+
|
||||
"PGPASSWORD environment variable may not be working correctly on this system. "+
|
||||
"Password length: %d, Password empty: %v. "+
|
||||
"Consider using a .pgpass file as an alternative. stderr: %s",
|
||||
len(password),
|
||||
password == "",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL SSL connection failed. The server may require SSL encryption or have SSL configuration issues. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection refused. Check if the server is running and accessible from your network. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "timeout") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
return nil, uc.buildPgDumpErrorMessage(waitErr, stderrOutput, pgBin, args, password)
|
||||
case copyErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fmt.Errorf("copy to storage: %w", copyErr)
|
||||
return nil, fmt.Errorf("copy to storage: %w", copyErr)
|
||||
case saveErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fmt.Errorf("save to storage: %w", saveErr)
|
||||
return nil, fmt.Errorf("save to storage: %w", saveErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
return &backupMetadata, nil
|
||||
}
|
||||
|
||||
// copyWithShutdownCheck copies data from src to dst while checking for shutdown
|
||||
func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (int64, error) {
|
||||
buf := make([]byte, 32*1024) // 32KB buffer
|
||||
buf := make([]byte, copyBufferSize)
|
||||
var totalBytesWritten int64
|
||||
|
||||
// Progress reporting interval - report every 1MB of data
|
||||
var lastReportedMB float64
|
||||
const reportIntervalMB = 1.0
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -463,12 +286,9 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
|
||||
// Report progress based on total size
|
||||
if backupProgressListener != nil {
|
||||
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
|
||||
|
||||
// Only report if we've written at least 1MB more data than last report
|
||||
if currentSizeMB >= lastReportedMB+reportIntervalMB {
|
||||
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
|
||||
backupProgressListener(currentSizeMB)
|
||||
lastReportedMB = currentSizeMB
|
||||
}
|
||||
@@ -479,7 +299,6 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -487,12 +306,412 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
// containsIgnoreCase checks if a string contains a substring, ignoring case
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlDatabase) []string {
|
||||
args := []string{
|
||||
"-Fc",
|
||||
"--no-password",
|
||||
"-h", pg.Host,
|
||||
"-p", strconv.Itoa(pg.Port),
|
||||
"-U", pg.Username,
|
||||
"-d", *pg.Database,
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
compressionArgs := uc.getCompressionArgs(pg.Version)
|
||||
return append(args, compressionArgs...)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) getCompressionArgs(
|
||||
version tools.PostgresqlVersion,
|
||||
) []string {
|
||||
if uc.isOlderPostgresVersion(version) {
|
||||
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", version)
|
||||
return []string{"-Z", strconv.Itoa(compressionLevel)}
|
||||
}
|
||||
|
||||
uc.logger.Info("Using zstd compression level 5", "version", version)
|
||||
return []string{fmt.Sprintf("--compress=zstd:%d", compressionLevel)}
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) isOlderPostgresVersion(
|
||||
version tools.PostgresqlVersion,
|
||||
) bool {
|
||||
return version == tools.PostgresqlVersion12 ||
|
||||
version == tools.PostgresqlVersion13 ||
|
||||
version == tools.PostgresqlVersion14 ||
|
||||
version == tools.PostgresqlVersion15
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
|
||||
parentCtx context.Context,
|
||||
) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(shutdownCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupPgpassFile(
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
pgpassFile, err := uc.createTempPgpassFile(pgConfig, password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
if pgpassFile == "" {
|
||||
return "", fmt.Errorf("temporary .pgpass file was not created")
|
||||
}
|
||||
|
||||
if info, err := os.Stat(pgpassFile); err == nil {
|
||||
uc.logger.Info("Temporary .pgpass file created successfully",
|
||||
"pgpassFile", pgpassFile,
|
||||
"size", info.Size(),
|
||||
"mode", info.Mode(),
|
||||
)
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to verify .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
return pgpassFile, nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupPgEnvironment(
|
||||
cmd *exec.Cmd,
|
||||
pgpassFile string,
|
||||
shouldRequireSSL bool,
|
||||
password string,
|
||||
cpuCount int,
|
||||
pgBin string,
|
||||
) error {
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
|
||||
|
||||
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
|
||||
uc.logger.Info("Setting up PostgreSQL environment",
|
||||
"passwordLength", len(password),
|
||||
"passwordEmpty", password == "",
|
||||
"pgBin", pgBin,
|
||||
"usingPgpassFile", true,
|
||||
"parallelJobs", cpuCount,
|
||||
)
|
||||
|
||||
cmd.Env = append(cmd.Env,
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT="+strconv.Itoa(pgConnectTimeout),
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
)
|
||||
|
||||
if shouldRequireSSL {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
|
||||
uc.logger.Info("Using required SSL mode", "configuredHttps", shouldRequireSSL)
|
||||
} else {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
|
||||
uc.logger.Info("Using preferred SSL mode", "configuredHttps", shouldRequireSSL)
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env,
|
||||
"PGSSLCERT=",
|
||||
"PGSSLKEY=",
|
||||
"PGSSLROOTCERT=",
|
||||
"PGSSLCRL=",
|
||||
)
|
||||
|
||||
if _, err := exec.LookPath(pgBin); err != nil {
|
||||
return fmt.Errorf("PostgreSQL executable not found or not accessible: %s - %w", pgBin, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
|
||||
metadata := BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
metadata.Encryption = backups_config.BackupEncryptionNone
|
||||
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
|
||||
return storageWriter, nil, metadata, nil
|
||||
}
|
||||
|
||||
salt, err := backup_encryption.GenerateSalt()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := backup_encryption.GenerateNonce()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyRepo.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
||||
storageWriter,
|
||||
masterKey,
|
||||
backupID,
|
||||
salt,
|
||||
nonce,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
||||
}
|
||||
|
||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
||||
metadata.EncryptionSalt = &saltBase64
|
||||
metadata.EncryptionIV = &nonceBase64
|
||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||
|
||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||
return encWriter, encWriter, metadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
saveErrCh chan error,
|
||||
) {
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
if closeErr := encryptionWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close encrypting writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
|
||||
}
|
||||
|
||||
<-saveErrCh
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
) error {
|
||||
encryptionCloseErrCh := make(chan error, 1)
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
closeErr := encryptionWriter.Close()
|
||||
if closeErr != nil {
|
||||
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
|
||||
}
|
||||
encryptionCloseErrCh <- closeErr
|
||||
}()
|
||||
} else {
|
||||
encryptionCloseErrCh <- nil
|
||||
}
|
||||
|
||||
encryptionCloseErr := <-encryptionCloseErrCh
|
||||
if encryptionCloseErr != nil {
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
|
||||
}
|
||||
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) checkCancellation(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) checkCancellationReason() error {
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
pgBin string,
|
||||
args []string,
|
||||
password string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf("%s failed: %v – stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
if exitCode == exitCodeGenericError && strings.TrimSpace(stderrStr) == "" {
|
||||
return uc.handleExitCode1NoStderr(pgBin, args)
|
||||
}
|
||||
|
||||
if exitCode == exitCodeAccessViolation {
|
||||
return uc.handleAccessViolation(pgBin, stderrStr)
|
||||
}
|
||||
|
||||
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
|
||||
return uc.handleConnectionErrors(stderrStr, password)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleExitCode1NoStderr(
|
||||
pgBin string,
|
||||
args []string,
|
||||
) error {
|
||||
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"env_vars", []string{
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT=" + strconv.Itoa(pgConnectTimeout),
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
},
|
||||
)
|
||||
|
||||
return fmt.Errorf(
|
||||
"%s failed with exit status 1 but provided no error details. "+
|
||||
"This often indicates: "+
|
||||
"1) Connection timeout or refused connection, "+
|
||||
"2) Authentication failure with incorrect credentials, "+
|
||||
"3) Database does not exist, "+
|
||||
"4) Network connectivity issues, "+
|
||||
"5) PostgreSQL server not running. "+
|
||||
"Command executed: %s %s",
|
||||
filepath.Base(pgBin),
|
||||
pgBin,
|
||||
strings.Join(args, " "),
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleAccessViolation(
|
||||
pgBin string,
|
||||
stderrStr string,
|
||||
) error {
|
||||
uc.logger.Error("PostgreSQL tool crashed with access violation",
|
||||
"pgBin", pgBin,
|
||||
"exitCode", "0xC0000005",
|
||||
)
|
||||
|
||||
return fmt.Errorf(
|
||||
"%s crashed with access violation (0xC0000005). "+
|
||||
"This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. "+
|
||||
"stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleConnectionErrors(
|
||||
stderrStr string,
|
||||
password string,
|
||||
) error {
|
||||
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection rejected by server configuration (pg_hba.conf). "+
|
||||
"The server may not allow connections from your IP address or may require different authentication settings. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "no password supplied") ||
|
||||
containsIgnoreCase(stderrStr, "fe_sendauth") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL authentication failed - no password supplied. "+
|
||||
"PGPASSWORD environment variable may not be working correctly on this system. "+
|
||||
"Password length: %d, Password empty: %v. "+
|
||||
"Consider using a .pgpass file as an alternative. "+
|
||||
"stderr: %s",
|
||||
len(password),
|
||||
password == "",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL SSL connection failed. "+
|
||||
"The server may require SSL encryption or have SSL configuration issues. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection refused. "+
|
||||
"Check if the server is running and accessible from your network. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "authentication") ||
|
||||
containsIgnoreCase(stderrStr, "password") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL authentication failed. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("PostgreSQL connection or authentication error. stderr: %s", stderrStr)
|
||||
}
|
||||
|
||||
// createTempPgpassFile creates a temporary .pgpass file with the given password
|
||||
func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
password string,
|
||||
@@ -508,7 +727,6 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
password,
|
||||
)
|
||||
|
||||
// it always create unique directory like /tmp/pgpass-1234567890
|
||||
tempDir, err := os.MkdirTemp("", "pgpass")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
@@ -522,3 +740,7 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
|
||||
return pgpassFile, nil
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import (
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import backups_config "postgresus-backend/internal/features/backups/config"
|
||||
|
||||
type EncryptionMetadata struct {
|
||||
Salt string
|
||||
IV string
|
||||
Encryption backups_config.BackupEncryption
|
||||
}
|
||||
|
||||
type BackupMetadata struct {
|
||||
EncryptionSalt *string
|
||||
EncryptionIV *string
|
||||
Encryption backups_config.BackupEncryption
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package backups_config
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type BackupConfigController struct {
|
||||
backupConfigService *BackupConfigService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -21,35 +20,29 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
|
||||
// SaveBackupConfig
|
||||
// @Summary Save backup configuration
|
||||
// @Description Save or update backup configuration for a database
|
||||
// @Description Save or update backup configuration for a database. Encryption can be set to NONE (no encryption) or ENCRYPTED (AES-256-GCM encryption).
|
||||
// @Tags backup-configs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body BackupConfig true "Backup configuration data"
|
||||
// @Success 200 {object} BackupConfig
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Param request body BackupConfig true "Backup configuration data (encryption field: NONE or ENCRYPTED)"
|
||||
// @Success 200 {object} BackupConfig "Returns the saved backup configuration including encryption settings"
|
||||
// @Failure 400 {object} map[string]string "Invalid encryption value or other validation errors"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 500 {object} map[string]string "Internal server error"
|
||||
// @Router /backup-configs/save [post]
|
||||
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var requestDTO BackupConfig
|
||||
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
// make sure we rely on full .Storage object
|
||||
requestDTO.StorageID = nil
|
||||
|
||||
@@ -64,40 +57,28 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
|
||||
|
||||
// GetBackupConfigByDbID
|
||||
// @Summary Get backup configuration by database ID
|
||||
// @Description Get backup configuration for a specific database
|
||||
// @Description Get backup configuration for a specific database including encryption settings (NONE or ENCRYPTED)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} BackupConfig
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 404
|
||||
// @Success 200 {object} BackupConfig "Returns backup configuration with encryption field"
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Backup configuration not found"
|
||||
// @Router /backup-configs/database/{id} [get]
|
||||
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := c.backupConfigService.GetBackupConfigByDbIdWithAuth(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "backup configuration not found"})
|
||||
@@ -119,24 +100,18 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backup-configs/storage/{id}/is-using [get]
|
||||
func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
isUsing, err := c.backupConfigService.IsStorageUsing(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
493
backend/internal/features/backups/config/controller_test.go
Normal file
493
backend/internal/features/backups/config/controller_test.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/period"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetBackupConfigController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can save backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
testResp := test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
|
||||
assert.Equal(t, 2, response.CpuCount)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+nonMember.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
testResp := test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "backup configuration not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
|
||||
assert.Equal(t, 1, response.CpuCount)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isStorageOwner bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "storage owner can check storage usage",
|
||||
isStorageOwner: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-storage-owner cannot check storage usage",
|
||||
isStorageOwner: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
storageOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"Test Workspace",
|
||||
storageOwner,
|
||||
router,
|
||||
)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isStorageOwner {
|
||||
testUserToken = storageOwner.Token
|
||||
} else {
|
||||
otherUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = otherUser.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response map[string]bool
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
isUsing, exists := response["isUsing"]
|
||||
assert.True(t, exists)
|
||||
assert.False(t, isUsing)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.Equal(t, BackupEncryptionNone, response.Encryption)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionEncrypted,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.Equal(t, BackupEncryptionEncrypted, response.Encryption)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
return storages.CreateTestStorage(workspaceID)
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package backups_config
|
||||
import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
@@ -11,11 +11,11 @@ var backupConfigService = &BackupConfigService{
|
||||
backupConfigRepository,
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
}
|
||||
var backupConfigController = &BackupConfigController{
|
||||
backupConfigService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetBackupConfigController() *BackupConfigController {
|
||||
|
||||
@@ -6,3 +6,10 @@ const (
|
||||
NotificationBackupFailed BackupNotificationType = "BACKUP_FAILED"
|
||||
NotificationBackupSuccess BackupNotificationType = "BACKUP_SUCCESS"
|
||||
)
|
||||
|
||||
type BackupEncryption string
|
||||
|
||||
const (
|
||||
BackupEncryptionNone BackupEncryption = "NONE"
|
||||
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
CpuCount int `json:"cpuCount" gorm:"type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -88,6 +90,11 @@ func (b *BackupConfig) Validate() error {
|
||||
return errors.New("max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
if b.Encryption != "" && b.Encryption != BackupEncryptionNone &&
|
||||
b.Encryption != BackupEncryptionEncrypted {
|
||||
return errors.New("encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -103,5 +110,6 @@ func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
CpuCount: b.CpuCount,
|
||||
Encryption: b.Encryption,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,6 +17,7 @@ type BackupConfigService struct {
|
||||
backupConfigRepository *BackupConfigRepository
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -32,11 +36,23 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
|
||||
database, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot save backup config for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to modify backup configuration")
|
||||
}
|
||||
|
||||
return s.SaveBackupConfig(backupConfig)
|
||||
}
|
||||
|
||||
@@ -66,19 +82,6 @@ func (s *BackupConfigService) SaveBackupConfig(
|
||||
}
|
||||
}
|
||||
|
||||
if !backupConfig.IsBackupsEnabled && existingConfig.StorageID != nil {
|
||||
if err := s.dbStorageChangeListener.OnBeforeBackupsStorageChange(
|
||||
backupConfig.DatabaseID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we clear storage for disabled backups to allow
|
||||
// storage removal for unused storages
|
||||
backupConfig.Storage = nil
|
||||
backupConfig.StorageID = nil
|
||||
}
|
||||
|
||||
return s.backupConfigRepository.Save(backupConfig)
|
||||
}
|
||||
|
||||
@@ -144,6 +147,10 @@ func (s *BackupConfigService) OnDatabaseCopied(originalDatabaseID, newDatabaseID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) error {
|
||||
return s.initializeDefaultConfig(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
@@ -164,6 +171,7 @@ func (s *BackupConfigService) initializeDefaultConfig(
|
||||
CpuCount: 1,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
|
||||
@@ -2,15 +2,18 @@ package databases
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DatabaseController struct {
|
||||
databaseService *DatabaseService
|
||||
userService *users.UserService
|
||||
databaseService *DatabaseService
|
||||
userService *users_services.UserService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -28,36 +31,35 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
|
||||
// CreateDatabase
|
||||
// @Summary Create a new database
|
||||
// @Description Create a new database configuration
|
||||
// @Description Create a new database configuration in a workspace
|
||||
// @Tags databases
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body Database true "Database creation data"
|
||||
// @Param request body Database true "Database creation data with workspaceId"
|
||||
// @Success 201 {object} Database
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases/create [post]
|
||||
func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
if request.WorkspaceID == nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
database, err := c.databaseService.CreateDatabase(user, &request)
|
||||
database, err := c.databaseService.CreateDatabase(user, *request.WorkspaceID, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -79,24 +81,18 @@ func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/update [post]
|
||||
func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.UpdateDatabase(user, &request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -116,24 +112,18 @@ func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/{id} [delete]
|
||||
func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.DeleteDatabase(user, id); err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -153,24 +143,18 @@ func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /databases/{id} [get]
|
||||
func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
database, err := c.databaseService.GetDatabase(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -181,30 +165,38 @@ func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// GetDatabases
|
||||
// @Summary Get databases
|
||||
// @Description Get all databases for the authenticated user
|
||||
// @Summary Get databases by workspace
|
||||
// @Description Get all databases for a specific workspace
|
||||
// @Tags databases
|
||||
// @Produce json
|
||||
// @Param workspace_id query string true "Workspace ID"
|
||||
// @Success 200 {array} Database
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases [get]
|
||||
func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
workspaceIDStr := ctx.Query("workspace_id")
|
||||
if workspaceIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
databases, err := c.databaseService.GetDatabasesByUser(user)
|
||||
workspaceID, err := uuid.Parse(workspaceIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
|
||||
return
|
||||
}
|
||||
|
||||
databases, err := c.databaseService.GetDatabasesByWorkspace(user, workspaceID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -222,24 +214,18 @@ func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/{id}/test-connection [post]
|
||||
func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.TestDatabaseConnection(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -259,27 +245,18 @@ func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /databases/test-connection-direct [post]
|
||||
func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
|
||||
_, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Set user ID for validation purposes
|
||||
request.UserID = user.ID
|
||||
|
||||
if err := c.databaseService.TestDatabaseConnectionDirect(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -300,24 +277,18 @@ func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/notifier/{id}/is-using [get]
|
||||
func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
isUsing, err := c.databaseService.IsNotifierUsing(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -339,24 +310,18 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/{id}/copy [post]
|
||||
func (c *DatabaseController) CopyDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
copiedDatabase, err := c.databaseService.CopyDatabase(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
|
||||
1002
backend/internal/features/databases/controller_test.go
Normal file
1002
backend/internal/features/databases/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,9 +5,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
"regexp"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -59,11 +59,51 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
|
||||
func (p *PostgresqlDatabase) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return testSingleDatabaseConnection(logger, ctx, p)
|
||||
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) HideSensitiveData() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.Password = ""
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
||||
p.Version = incoming.Version
|
||||
p.Host = incoming.Host
|
||||
p.Port = incoming.Port
|
||||
p.Username = incoming.Username
|
||||
p.Database = incoming.Database
|
||||
p.IsHttps = incoming.IsHttps
|
||||
|
||||
if incoming.Password != "" {
|
||||
p.Password = incoming.Password
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
databaseID uuid.UUID,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if p.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Password = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
|
||||
@@ -71,14 +111,22 @@ func testSingleDatabaseConnection(
|
||||
logger *slog.Logger,
|
||||
ctx context.Context,
|
||||
postgresDb *PostgresqlDatabase,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
// For single database backup, we need to connect to the specific database
|
||||
if postgresDb.Database == nil || *postgresDb.Database == "" {
|
||||
return errors.New("database name is required for single database backup (pg_dump)")
|
||||
}
|
||||
|
||||
// Decrypt password if needed
|
||||
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Build connection string for the specific database
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
|
||||
|
||||
// Test connection
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
@@ -161,7 +209,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
|
||||
}
|
||||
|
||||
// buildConnectionStringForDB builds connection string for specific database
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
|
||||
sslMode := "disable"
|
||||
if p.IsHttps {
|
||||
sslMode = "require"
|
||||
@@ -171,106 +219,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
p.Host,
|
||||
p.Port,
|
||||
p.Username,
|
||||
p.Password,
|
||||
password,
|
||||
dbName,
|
||||
sslMode,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) InstallExtensions(extensions []tools.PostgresqlExtension) error {
|
||||
if len(extensions) == 0 {
|
||||
return nil
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, error) {
|
||||
if encryptor == nil {
|
||||
return password, nil
|
||||
}
|
||||
|
||||
if p.Database == nil || *p.Database == "" {
|
||||
return errors.New("database name is required for installing extensions")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Build connection string for the specific database
|
||||
connStr := buildConnectionStringForDB(p, *p.Database)
|
||||
|
||||
// Connect to database
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database '%s': %w", *p.Database, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
fmt.Println("failed to close connection: %w", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Check which extensions are already installed
|
||||
installedExtensions, err := p.getInstalledExtensions(ctx, conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check installed extensions: %w", err)
|
||||
}
|
||||
|
||||
// Install missing extensions
|
||||
for _, extension := range extensions {
|
||||
if contains(installedExtensions, string(extension)) {
|
||||
continue // Extension already installed
|
||||
}
|
||||
|
||||
if err := p.installExtension(ctx, conn, string(extension)); err != nil {
|
||||
return fmt.Errorf("failed to install extension '%s': %w", extension, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getInstalledExtensions queries the database for currently installed extensions
|
||||
func (p *PostgresqlDatabase) getInstalledExtensions(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
) ([]string, error) {
|
||||
query := "SELECT extname FROM pg_extension"
|
||||
|
||||
rows, err := conn.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query installed extensions: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var extensions []string
|
||||
for rows.Next() {
|
||||
var extname string
|
||||
|
||||
if err := rows.Scan(&extname); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan extension name: %w", err)
|
||||
}
|
||||
|
||||
extensions = append(extensions, extname)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating over extension rows: %w", err)
|
||||
}
|
||||
|
||||
return extensions, nil
|
||||
}
|
||||
|
||||
// installExtension installs a single PostgreSQL extension
|
||||
func (p *PostgresqlDatabase) installExtension(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
extensionName string,
|
||||
) error {
|
||||
query := fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s", extensionName)
|
||||
|
||||
_, err := conn.Exec(ctx, query)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute CREATE EXTENSION: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// contains checks if a string slice contains a specific string
|
||||
func contains(slice []string, item string) bool {
|
||||
return slices.Contains(slice, item)
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -15,11 +18,15 @@ var databaseService = &DatabaseService{
|
||||
[]DatabaseCreationListener{},
|
||||
[]DatabaseRemoveListener{},
|
||||
[]DatabaseCopyListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
var databaseController = &DatabaseController{
|
||||
databaseService,
|
||||
users.GetUserService(),
|
||||
users_services.GetUserService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
func GetDatabaseService() *DatabaseService {
|
||||
@@ -29,3 +36,7 @@ func GetDatabaseService() *DatabaseService {
|
||||
func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package databases
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -11,7 +12,13 @@ type DatabaseValidator interface {
|
||||
}
|
||||
|
||||
type DatabaseConnector interface {
|
||||
TestConnection(logger *slog.Logger) error
|
||||
TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error
|
||||
|
||||
HideSensitiveData()
|
||||
}
|
||||
|
||||
type DatabaseCreationListener interface {
|
||||
|
||||
@@ -5,16 +5,20 @@ import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
UserID uuid.UUID `json:"userId" gorm:"column:user_id;type:uuid;not null"`
|
||||
Name string `json:"name" gorm:"column:name;type:text;not null"`
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
|
||||
// WorkspaceID can be null when a database is created via restore operation
|
||||
// outside the context of any workspace
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;type:uuid"`
|
||||
Name string `json:"name" gorm:"column:name;type:text;not null"`
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
|
||||
@@ -53,8 +57,35 @@ func (d *Database) ValidateUpdate(old, new Database) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) TestConnection(logger *slog.Logger) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger)
|
||||
func (d *Database) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
|
||||
}
|
||||
|
||||
func (d *Database) HideSensitiveData() {
|
||||
d.getSpecificDatabase().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Update(incoming *Database) {
|
||||
d.Name = incoming.Name
|
||||
d.Type = incoming.Type
|
||||
d.Notifiers = incoming.Notifiers
|
||||
|
||||
switch d.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if d.Postgresql != nil && incoming.Postgresql != nil {
|
||||
d.Postgresql.Update(incoming.Postgresql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) getSpecificDatabase() DatabaseConnector {
|
||||
|
||||
@@ -92,14 +92,14 @@ func (r *DatabaseRepository) FindByID(id uuid.UUID) (*Database, error) {
|
||||
return &database, nil
|
||||
}
|
||||
|
||||
func (r *DatabaseRepository) FindByUserID(userID uuid.UUID) ([]*Database, error) {
|
||||
func (r *DatabaseRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Database, error) {
|
||||
var databases []*Database
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Postgresql").
|
||||
Preload("Notifiers").
|
||||
Where("user_id = ?", userID).
|
||||
Where("workspace_id = ?", workspaceID).
|
||||
Order("CASE WHEN health_status = 'UNAVAILABLE' THEN 1 WHEN health_status = 'AVAILABLE' THEN 2 WHEN health_status IS NULL THEN 3 ELSE 4 END, name ASC").
|
||||
Find(&databases).Error; err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -2,11 +2,16 @@ package databases
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
"time"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -19,6 +24,10 @@ type DatabaseService struct {
|
||||
dbCreationListener []DatabaseCreationListener
|
||||
dbRemoveListener []DatabaseRemoveListener
|
||||
dbCopyListener []DatabaseCopyListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *DatabaseService) AddDbCreationListener(
|
||||
@@ -41,15 +50,28 @@ func (s *DatabaseService) AddDbCopyListener(
|
||||
|
||||
func (s *DatabaseService) CreateDatabase(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
database *Database,
|
||||
) (*Database, error) {
|
||||
database.UserID = user.ID
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to create database in this workspace")
|
||||
}
|
||||
|
||||
database.WorkspaceID = &workspaceID
|
||||
|
||||
if err := database.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database, err := s.dbRepository.Save(database)
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
database, err = s.dbRepository.Save(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -58,6 +80,12 @@ func (s *DatabaseService) CreateDatabase(
|
||||
listener.OnDatabaseCreated(database.ID)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database created: %s", database.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
@@ -74,24 +102,43 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if existingDatabase.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return errors.New("cannot update database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to update this database")
|
||||
}
|
||||
|
||||
// Validate the update
|
||||
if err := database.ValidateUpdate(*existingDatabase, *database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := database.Validate(); err != nil {
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.dbRepository.Save(database)
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.dbRepository.Save(existingDatabase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -104,8 +151,16 @@ func (s *DatabaseService) DeleteDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if existingDatabase.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return errors.New("cannot delete database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to delete this database")
|
||||
}
|
||||
|
||||
for _, listener := range s.dbRemoveListener {
|
||||
@@ -114,6 +169,12 @@ func (s *DatabaseService) DeleteDatabase(
|
||||
}
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database deleted: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.dbRepository.Delete(id)
|
||||
}
|
||||
|
||||
@@ -126,17 +187,44 @@ func (s *DatabaseService) GetDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("you have not access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access this database")
|
||||
}
|
||||
|
||||
database.HideSensitiveData()
|
||||
return database, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetDatabasesByUser(
|
||||
func (s *DatabaseService) GetDatabasesByWorkspace(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
) ([]*Database, error) {
|
||||
return s.dbRepository.FindByUserID(user.ID)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access this workspace")
|
||||
}
|
||||
|
||||
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, database := range databases {
|
||||
database.HideSensitiveData()
|
||||
}
|
||||
|
||||
return databases, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) IsNotifierUsing(
|
||||
@@ -160,11 +248,19 @@ func (s *DatabaseService) TestDatabaseConnection(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot test connection for database without workspace")
|
||||
}
|
||||
|
||||
err = database.TestConnection(s.logger)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canAccess {
|
||||
return errors.New("insufficient permissions to test connection for this database")
|
||||
}
|
||||
|
||||
err = database.TestConnection(s.logger, s.fieldEncryptor)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
database.LastBackupErrorMessage = &lastSaveError
|
||||
@@ -184,7 +280,31 @@ func (s *DatabaseService) TestDatabaseConnection(
|
||||
func (s *DatabaseService) TestDatabaseConnectionDirect(
|
||||
database *Database,
|
||||
) error {
|
||||
return database.TestConnection(s.logger)
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && existingDatabase.WorkspaceID != nil &&
|
||||
*existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
} else {
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetDatabaseByID(
|
||||
@@ -237,13 +357,21 @@ func (s *DatabaseService) CopyDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if existingDatabase.UserID != user.ID {
|
||||
return nil, errors.New("you have not access to this database")
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot copy database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to copy this database")
|
||||
}
|
||||
|
||||
newDatabase := &Database{
|
||||
ID: uuid.Nil,
|
||||
UserID: user.ID,
|
||||
WorkspaceID: existingDatabase.WorkspaceID,
|
||||
Name: existingDatabase.Name + " (Copy)",
|
||||
Type: existingDatabase.Type,
|
||||
Notifiers: existingDatabase.Notifiers,
|
||||
@@ -286,6 +414,12 @@ func (s *DatabaseService) CopyDatabase(
|
||||
listener.OnDatabaseCopied(databaseID, copiedDatabase.ID)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database copied: %s to %s", existingDatabase.Name, copiedDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return copiedDatabase, nil
|
||||
}
|
||||
|
||||
@@ -306,3 +440,19 @@ func (s *DatabaseService) SetHealthStatus(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(databases) > 0 {
|
||||
return fmt.Errorf(
|
||||
"workspace contains %d databases that must be deleted",
|
||||
len(databases),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
)
|
||||
|
||||
func CreateTestDatabase(
|
||||
userID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
storage *storages.Storage,
|
||||
notifier *notifiers.Notifier,
|
||||
) *Database {
|
||||
database := &Database{
|
||||
UserID: userID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
|
||||
@@ -10,23 +10,34 @@ import (
|
||||
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
// Create workspace directly via service
|
||||
workspace, err := workspaces_testing.CreateTestWorkspaceDirect("Test Workspace", user.UserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
defer storages.RemoveTestStorage(storage.ID)
|
||||
defer notifiers.RemoveTestNotifier(notifier)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspaceDirect(workspace.ID)
|
||||
}()
|
||||
|
||||
t.Run("Test_DbAttemptFailed_DbMarkedAsUnavailable", func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
@@ -94,7 +105,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbNotMarkerdAsDownAfterFirstAttempt",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
@@ -160,7 +171,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbMarkerdAsDownAfterThirdFailedAttempt",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Make sure DB is available
|
||||
@@ -237,7 +248,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
)
|
||||
|
||||
t.Run("Test_UnavailableDbAttemptSucceed_DbMarkedAsAvailable", func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Make sure DB is unavailable
|
||||
@@ -311,7 +322,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbHealthcheckExecutedFast_HealthcheckNotExecutedFasterThanInterval",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
|
||||
@@ -2,7 +2,7 @@ package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
type HealthcheckAttemptController struct {
|
||||
healthcheckAttemptService *HealthcheckAttemptService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -31,9 +30,9 @@ func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-attempts/{databaseId} [get]
|
||||
func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -43,7 +42,7 @@ func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
afterDate := time.Now().UTC()
|
||||
afterDate := time.Now().UTC().Add(-7 * 24 * time.Hour)
|
||||
if afterDateStr := ctx.Query("afterDate"); afterDateStr != "" {
|
||||
parsedDate, err := time.Parse(time.RFC3339, afterDateStr)
|
||||
if err != nil {
|
||||
|
||||
261
backend/internal/features/healthcheck/attempt/controller_test.go
Normal file
261
backend/internal/features/healthcheck/attempt/controller_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetHealthcheckAttemptController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get healthcheck attempts",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get healthcheck attempts",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
pastTime := time.Now().UTC().Add(-1 * time.Hour)
|
||||
createTestHealthcheckAttemptWithTime(
|
||||
database.ID,
|
||||
databases.HealthStatusAvailable,
|
||||
pastTime,
|
||||
)
|
||||
createTestHealthcheckAttemptWithTime(
|
||||
database.ID,
|
||||
databases.HealthStatusUnavailable,
|
||||
pastTime.Add(-30*time.Minute),
|
||||
)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response), 2)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "forbidden")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
oldTime := time.Now().UTC().Add(-2 * time.Hour)
|
||||
recentTime := time.Now().UTC().Add(-30 * time.Minute)
|
||||
|
||||
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusAvailable, oldTime)
|
||||
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusUnavailable, recentTime)
|
||||
createTestHealthcheckAttempt(database.ID, databases.HealthStatusAvailable)
|
||||
|
||||
afterDate := time.Now().UTC().Add(-1 * time.Hour)
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/healthcheck-attempts/%s?afterDate=%s",
|
||||
database.ID.String(),
|
||||
afterDate.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, 2, len(response))
|
||||
for _, attempt := range response {
|
||||
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, 0, len(response))
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestHealthcheckAttempt(databaseID uuid.UUID, status databases.HealthStatus) {
|
||||
createTestHealthcheckAttemptWithTime(databaseID, status, time.Now().UTC())
|
||||
}
|
||||
|
||||
func createTestHealthcheckAttemptWithTime(
|
||||
databaseID uuid.UUID,
|
||||
status databases.HealthStatus,
|
||||
createdAt time.Time,
|
||||
) {
|
||||
repo := GetHealthcheckAttemptRepository()
|
||||
attempt := &HealthcheckAttempt{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
Status: status,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
if err := repo.Create(attempt); err != nil {
|
||||
panic("Failed to create test healthcheck attempt: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ var healthcheckAttemptRepository = &HealthcheckAttemptRepository{}
|
||||
var healthcheckAttemptService = &HealthcheckAttemptService{
|
||||
healthcheckAttemptRepository,
|
||||
databases.GetDatabaseService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
var checkPgHealthUseCase = &CheckPgHealthUseCase{
|
||||
@@ -27,7 +28,10 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
}
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
healthcheckAttemptService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetHealthcheckAttemptRepository() *HealthcheckAttemptRepository {
|
||||
return healthcheckAttemptRepository
|
||||
}
|
||||
|
||||
func GetHealthcheckAttemptService() *HealthcheckAttemptService {
|
||||
|
||||
@@ -53,7 +53,7 @@ func (r *HealthcheckAttemptRepository) DeleteOlderThan(
|
||||
Delete(&HealthcheckAttempt{}).Error
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) Insert(
|
||||
func (r *HealthcheckAttemptRepository) Create(
|
||||
attempt *HealthcheckAttempt,
|
||||
) error {
|
||||
if attempt.ID == uuid.Nil {
|
||||
@@ -67,6 +67,12 @@ func (r *HealthcheckAttemptRepository) Insert(
|
||||
return storage.GetDb().Create(attempt).Error
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) Insert(
|
||||
attempt *HealthcheckAttempt,
|
||||
) error {
|
||||
return r.Create(attempt)
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) FindByDatabaseIDWithLimit(
|
||||
databaseID uuid.UUID,
|
||||
limit int,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
type HealthcheckAttemptService struct {
|
||||
healthcheckAttemptRepository *HealthcheckAttemptRepository
|
||||
databaseService *databases.DatabaseService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
|
||||
@@ -24,7 +26,15 @@ func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access healthcheck attempts for databases without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("forbidden")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package healthcheck_config
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type HealthcheckConfigController struct {
|
||||
healthcheckConfigService *HealthcheckConfigService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -31,9 +30,9 @@ func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-config [post]
|
||||
func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,9 +64,9 @@ func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-config/{databaseId} [get]
|
||||
func (c *HealthcheckConfigController) GetHealthcheckConfig(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
328
backend/internal/features/healthcheck/config/controller_test.go
Normal file
328
backend/internal/features/healthcheck/config/controller_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetHealthcheckConfigController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can save healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
request := HealthcheckConfigDTO{
|
||||
DatabaseID: database.ID,
|
||||
IsHealthcheckEnabled: true,
|
||||
IsSentNotificationWhenUnavailable: true,
|
||||
IntervalMinutes: 5,
|
||||
AttemptsBeforeConcideredAsDown: 3,
|
||||
StoreAttemptsDays: 7,
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response map[string]string
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
assert.Contains(t, response["message"], "successfully")
|
||||
} else {
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
request := HealthcheckConfigDTO{
|
||||
DatabaseID: database.ID,
|
||||
IsHealthcheckEnabled: true,
|
||||
IsSentNotificationWhenUnavailable: true,
|
||||
IntervalMinutes: 5,
|
||||
AttemptsBeforeConcideredAsDown: 3,
|
||||
StoreAttemptsDays: 7,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+nonMember.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response HealthcheckConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsHealthcheckEnabled)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response HealthcheckConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsHealthcheckEnabled)
|
||||
assert.True(t, response.IsSentNotificationWhenUnavailable)
|
||||
assert.Equal(t, 1, response.IntervalMinutes)
|
||||
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
|
||||
assert.Equal(t, 7, response.StoreAttemptsDays)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -10,11 +11,12 @@ var healthcheckConfigRepository = &HealthcheckConfigRepository{}
|
||||
var healthcheckConfigService = &HealthcheckConfigService{
|
||||
databases.GetDatabaseService(),
|
||||
healthcheckConfigRepository,
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var healthcheckConfigController = &HealthcheckConfigController{
|
||||
healthcheckConfigService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetHealthcheckConfigService() *HealthcheckConfigService {
|
||||
|
||||
@@ -2,9 +2,12 @@ package healthcheck_config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -12,6 +15,8 @@ import (
|
||||
type HealthcheckConfigService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
healthcheckConfigRepository *HealthcheckConfigRepository
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
@@ -33,8 +38,16 @@ func (s *HealthcheckConfigService) Save(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot modify healthcheck config for databases without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to modify healthcheck config")
|
||||
}
|
||||
|
||||
healthcheckConfig := configDTO.ToDTO()
|
||||
@@ -60,6 +73,12 @@ func (s *HealthcheckConfigService) Save(
|
||||
}
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Healthcheck config updated for database '%s'", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -72,8 +91,16 @@ func (s *HealthcheckConfigService) GetByDatabaseID(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access healthcheck config for databases without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to view healthcheck config")
|
||||
}
|
||||
|
||||
config, err := s.healthcheckConfigRepository.GetByDatabaseID(database.ID)
|
||||
|
||||
@@ -2,15 +2,16 @@ package notifiers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotifierController struct {
|
||||
notifierService *NotifierService
|
||||
userService *users.UserService
|
||||
notifierService *NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -29,35 +30,40 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param notifier body Notifier true "Notifier data"
|
||||
// @Param request body Notifier true "Notifier data with workspaceId"
|
||||
// @Success 200 {object} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers [post]
|
||||
func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var notifier Notifier
|
||||
if err := ctx.ShouldBindJSON(¬ifier); err != nil {
|
||||
var request Notifier
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := notifier.Validate(); err != nil {
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SaveNotifier(user, ¬ifier); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, notifier)
|
||||
ctx.JSON(http.StatusOK, request)
|
||||
}
|
||||
|
||||
// GetNotifier
|
||||
@@ -70,11 +76,12 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
|
||||
// @Success 200 {object} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id} [get]
|
||||
func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,6 +93,10 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
|
||||
notifier, err := c.notifierService.GetNotifier(user, id)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -95,22 +106,41 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
|
||||
// GetNotifiers
|
||||
// @Summary Get all notifiers
|
||||
// @Description Get all notifiers for the current user
|
||||
// @Description Get all notifiers for a workspace
|
||||
// @Tags notifiers
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param workspace_id query string true "Workspace ID"
|
||||
// @Success 200 {array} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers [get]
|
||||
func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
notifiers, err := c.notifierService.GetNotifiers(user)
|
||||
workspaceIDStr := ctx.Query("workspace_id")
|
||||
if workspaceIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
workspaceID, err := uuid.Parse(workspaceIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
|
||||
return
|
||||
}
|
||||
|
||||
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -128,11 +158,12 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id} [delete]
|
||||
func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,13 +173,11 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
notifier, err := c.notifierService.GetNotifier(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.DeleteNotifier(user, notifier.ID); err != nil {
|
||||
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -166,11 +195,12 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id}/test [post]
|
||||
func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -181,6 +211,10 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotification(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to test notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -195,28 +229,44 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param notifier body Notifier true "Notifier data"
|
||||
// @Param request body Notifier true "Notifier data with workspaceId"
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/direct-test [post]
|
||||
func (c *NotifierController) SendTestNotificationDirect(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var notifier Notifier
|
||||
if err := ctx.ShouldBindJSON(¬ifier); err != nil {
|
||||
var request Notifier
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// For direct test, associate with the current user
|
||||
notifier.UserID = user.ID
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotificationToNotifier(¬ifier); err != nil {
|
||||
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !canView {
|
||||
ctx.JSON(
|
||||
http.StatusForbidden,
|
||||
gin.H{"error": "insufficient permissions to test notifier in this workspace"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotificationToNotifier(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
1041
backend/internal/features/notifiers/controller_test.go
Normal file
1041
backend/internal/features/notifiers/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,9 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/users"
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -9,10 +11,13 @@ var notifierRepository = &NotifierRepository{}
|
||||
var notifierService = &NotifierService{
|
||||
notifierRepository,
|
||||
logger.GetLogger(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var notifierController = &NotifierController{
|
||||
notifierService,
|
||||
users.GetUserService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
func GetNotifierController() *NotifierController {
|
||||
@@ -22,3 +27,10 @@ func GetNotifierController() *NotifierController {
|
||||
func GetNotifierService() *NotifierService {
|
||||
return notifierService
|
||||
}
|
||||
|
||||
func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
}
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
package notifiers
|
||||
|
||||
import "log/slog"
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
Send(logger *slog.Logger, heading string, message string) error
|
||||
Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error
|
||||
|
||||
Validate() error
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
HideSensitiveData()
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
@@ -9,13 +9,14 @@ import (
|
||||
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
|
||||
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
|
||||
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null;type:uuid;index"`
|
||||
WorkspaceID uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;not null;type:uuid;index"`
|
||||
Name string `json:"name" gorm:"column:name;not null;type:varchar(255)"`
|
||||
NotifierType NotifierType `json:"notifierType" gorm:"column:notifier_type;not null;type:varchar(50)"`
|
||||
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
|
||||
@@ -33,16 +34,21 @@ func (n *Notifier) TableName() string {
|
||||
return "notifiers"
|
||||
}
|
||||
|
||||
func (n *Notifier) Validate() error {
|
||||
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
|
||||
return n.getSpecificNotifier().Validate()
|
||||
return n.getSpecificNotifier().Validate(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
err := n.getSpecificNotifier().Send(logger, heading, message)
|
||||
func (n *Notifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
|
||||
|
||||
if err != nil {
|
||||
lastSendError := err.Error()
|
||||
@@ -54,6 +60,46 @@ func (n *Notifier) Send(logger *slog.Logger, heading string, message string) err
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *Notifier) HideSensitiveData() {
|
||||
n.getSpecificNotifier().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Update(incoming *Notifier) {
|
||||
n.Name = incoming.Name
|
||||
n.NotifierType = incoming.NotifierType
|
||||
|
||||
switch n.NotifierType {
|
||||
case NotifierTypeTelegram:
|
||||
if n.TelegramNotifier != nil && incoming.TelegramNotifier != nil {
|
||||
n.TelegramNotifier.Update(incoming.TelegramNotifier)
|
||||
}
|
||||
case NotifierTypeEmail:
|
||||
if n.EmailNotifier != nil && incoming.EmailNotifier != nil {
|
||||
n.EmailNotifier.Update(incoming.EmailNotifier)
|
||||
}
|
||||
case NotifierTypeWebhook:
|
||||
if n.WebhookNotifier != nil && incoming.WebhookNotifier != nil {
|
||||
n.WebhookNotifier.Update(incoming.WebhookNotifier)
|
||||
}
|
||||
case NotifierTypeSlack:
|
||||
if n.SlackNotifier != nil && incoming.SlackNotifier != nil {
|
||||
n.SlackNotifier.Update(incoming.SlackNotifier)
|
||||
}
|
||||
case NotifierTypeDiscord:
|
||||
if n.DiscordNotifier != nil && incoming.DiscordNotifier != nil {
|
||||
n.DiscordNotifier.Update(incoming.DiscordNotifier)
|
||||
}
|
||||
case NotifierTypeTeams:
|
||||
if n.TeamsNotifier != nil && incoming.TeamsNotifier != nil {
|
||||
n.TeamsNotifier.Update(incoming.TeamsNotifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) getSpecificNotifier() NotificationSender {
|
||||
switch n.NotifierType {
|
||||
case NotifierTypeTelegram:
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
|
||||
return "discord_notifiers"
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Validate() error {
|
||||
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (d *DiscordNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal Discord payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
|
||||
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -71,3 +82,24 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) HideSensitiveData() {
|
||||
d.ChannelWebhookURL = ""
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
|
||||
if incoming.ChannelWebhookURL != "" {
|
||||
d.ChannelWebhookURL = incoming.ChannelWebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
d.ChannelWebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -27,13 +28,14 @@ type EmailNotifier struct {
|
||||
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
|
||||
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
|
||||
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
|
||||
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) TableName() string {
|
||||
return "email_notifiers"
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Validate() error {
|
||||
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if e.TargetEmail == "" {
|
||||
return errors.New("target email is required")
|
||||
}
|
||||
@@ -54,11 +56,29 @@ func (e *EmailNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (e *EmailNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
// Decrypt SMTP password if provided
|
||||
var smtpPassword string
|
||||
if e.SMTPPassword != "" {
|
||||
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
|
||||
}
|
||||
smtpPassword = decrypted
|
||||
}
|
||||
|
||||
// Compose email
|
||||
from := e.SMTPUser
|
||||
from := e.From
|
||||
if from == "" {
|
||||
from = "noreply@" + e.SMTPHost
|
||||
from = e.SMTPUser
|
||||
if from == "" {
|
||||
from = "noreply@" + e.SMTPHost
|
||||
}
|
||||
}
|
||||
|
||||
to := []string{e.TargetEmail}
|
||||
@@ -72,15 +92,16 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
)
|
||||
body := message
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", from)
|
||||
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
|
||||
|
||||
// Combine all parts of the email
|
||||
emailContent := []byte(fromHeader + subject + mime + body)
|
||||
emailContent := []byte(fromHeader + toHeader + subject + mime + body)
|
||||
|
||||
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
|
||||
timeout := DefaultTimeout
|
||||
|
||||
// Determine if authentication is required
|
||||
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
|
||||
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
|
||||
|
||||
// Handle different port scenarios
|
||||
if e.SMTPPort == ImplicitTLSPort {
|
||||
@@ -111,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Set up authentication only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -174,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Authenticate only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -208,3 +229,30 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
return client.Quit()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) HideSensitiveData() {
|
||||
e.SMTPPassword = ""
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Update(incoming *EmailNotifier) {
|
||||
e.TargetEmail = incoming.TargetEmail
|
||||
e.SMTPHost = incoming.SMTPHost
|
||||
e.SMTPPort = incoming.SMTPPort
|
||||
e.SMTPUser = incoming.SMTPUser
|
||||
e.From = incoming.From
|
||||
|
||||
if incoming.SMTPPassword != "" {
|
||||
e.SMTPPassword = incoming.SMTPPassword
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if e.SMTPPassword != "" {
|
||||
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
|
||||
}
|
||||
e.SMTPPassword = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,7 +24,7 @@ type SlackNotifier struct {
|
||||
|
||||
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
|
||||
|
||||
func (s *SlackNotifier) Validate() error {
|
||||
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
|
||||
func (s *SlackNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
full := fmt.Sprintf("*%s*", heading)
|
||||
|
||||
if message != "" {
|
||||
@@ -80,7 +90,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", "Bearer "+s.BotToken)
|
||||
req.Header.Set("Authorization", "Bearer "+botToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
@@ -132,3 +142,26 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) HideSensitiveData() {
|
||||
s.BotToken = ""
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) Update(incoming *SlackNotifier) {
|
||||
s.TargetChatID = incoming.TargetChatID
|
||||
|
||||
if incoming.BotToken != "" {
|
||||
s.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
s.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -21,11 +22,17 @@ func (TeamsNotifier) TableName() string {
|
||||
return "teams_notifiers"
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Validate() error {
|
||||
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL == "" {
|
||||
return errors.New("webhook_url is required")
|
||||
}
|
||||
u, err := url.Parse(n.WebhookURL)
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(webhookURL)
|
||||
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return errors.New("invalid webhook_url")
|
||||
}
|
||||
@@ -33,8 +40,8 @@ func (n *TeamsNotifier) Validate() error {
|
||||
}
|
||||
|
||||
type cardAttachment struct {
|
||||
ContentType string `json:"contentType"`
|
||||
Content interface{} `json:"content"`
|
||||
ContentType string `json:"contentType"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
@@ -43,11 +50,20 @@ type payload struct {
|
||||
Attachments []cardAttachment `json:"attachments,omitempty"`
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error {
|
||||
if err := n.Validate(); err != nil {
|
||||
func (n *TeamsNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
if err := n.Validate(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
card := map[string]any{
|
||||
"type": "AdaptiveCard",
|
||||
"version": "1.4",
|
||||
@@ -71,7 +87,7 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(p)
|
||||
req, err := http.NewRequest(http.MethodPost, n.WebhookURL, bytes.NewReader(body))
|
||||
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -94,3 +110,24 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) HideSensitiveData() {
|
||||
n.WebhookURL = ""
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
|
||||
if incoming.WebhookURL != "" {
|
||||
n.WebhookURL = incoming.WebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
n.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
|
||||
return "telegram_notifiers"
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Validate() error {
|
||||
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *TelegramNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("chat_id", t.TargetChatID)
|
||||
@@ -80,3 +91,27 @@ func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message str
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) HideSensitiveData() {
|
||||
t.BotToken = ""
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
|
||||
t.TargetChatID = incoming.TargetChatID
|
||||
t.ThreadID = incoming.ThreadID
|
||||
|
||||
if incoming.BotToken != "" {
|
||||
t.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
t.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
|
||||
return "webhook_notifiers"
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Validate() error {
|
||||
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *WebhookNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
|
||||
t.WebhookURL,
|
||||
webhookURL,
|
||||
url.QueryEscape(heading),
|
||||
url.QueryEscape(message),
|
||||
)
|
||||
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
|
||||
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send POST webhook: %w", err)
|
||||
}
|
||||
@@ -102,3 +113,22 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) HideSensitiveData() {
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
t.WebhookURL = incoming.WebhookURL
|
||||
t.WebhookMethod = incoming.WebhookMethod
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -143,7 +143,7 @@ func (r *NotifierRepository) FindByID(id uuid.UUID) (*Notifier, error) {
|
||||
return ¬ifier, nil
|
||||
}
|
||||
|
||||
func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error) {
|
||||
func (r *NotifierRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Notifier, error) {
|
||||
var notifiers []*Notifier
|
||||
|
||||
if err := storage.
|
||||
@@ -154,7 +154,7 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
|
||||
Preload("SlackNotifier").
|
||||
Preload("DiscordNotifier").
|
||||
Preload("TeamsNotifier").
|
||||
Where("user_id = ?", userID).
|
||||
Where("workspace_id = ?", workspaceID).
|
||||
Order("name ASC").
|
||||
Find(¬ifiers).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -165,7 +165,6 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
|
||||
|
||||
func (r *NotifierRepository) Delete(notifier *Notifier) error {
|
||||
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
||||
|
||||
switch notifier.NotifierType {
|
||||
case NotifierTypeTelegram:
|
||||
if notifier.TelegramNotifier != nil {
|
||||
|
||||
@@ -2,8 +2,13 @@ package notifiers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -11,30 +16,77 @@ import (
|
||||
type NotifierService struct {
|
||||
notifierRepository *NotifierRepository
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *NotifierService) SaveNotifier(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
notifier *Notifier,
|
||||
) error {
|
||||
if notifier.ID != uuid.Nil {
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to manage notifier in this workspace")
|
||||
}
|
||||
|
||||
isUpdate := notifier.ID != uuid.Nil
|
||||
|
||||
if isUpdate {
|
||||
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existingNotifier.UserID != user.ID {
|
||||
return errors.New("you have not access to this notifier")
|
||||
if existingNotifier.WorkspaceID != workspaceID {
|
||||
return errors.New("notifier does not belong to this workspace")
|
||||
}
|
||||
|
||||
notifier.UserID = existingNotifier.UserID
|
||||
} else {
|
||||
notifier.UserID = user.ID
|
||||
}
|
||||
existingNotifier.Update(notifier)
|
||||
|
||||
_, err := s.notifierRepository.Save(notifier)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.notifierRepository.Save(existingNotifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
} else {
|
||||
notifier.WorkspaceID = workspaceID
|
||||
|
||||
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := notifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.notifierRepository.Save(notifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier created: %s", notifier.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -49,11 +101,26 @@ func (s *NotifierService) DeleteNotifier(
|
||||
return err
|
||||
}
|
||||
|
||||
if notifier.UserID != user.ID {
|
||||
return errors.New("you have not access to this notifier")
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(notifier.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to manage notifier in this workspace")
|
||||
}
|
||||
|
||||
return s.notifierRepository.Delete(notifier)
|
||||
err = s.notifierRepository.Delete(notifier)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Notifier deleted: %s", notifier.Name),
|
||||
&user.ID,
|
||||
¬ifier.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *NotifierService) GetNotifier(
|
||||
@@ -65,17 +132,40 @@ func (s *NotifierService) GetNotifier(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if notifier.UserID != user.ID {
|
||||
return nil, errors.New("you have not access to this notifier")
|
||||
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canView {
|
||||
return nil, errors.New("insufficient permissions to view notifier in this workspace")
|
||||
}
|
||||
|
||||
notifier.HideSensitiveData()
|
||||
return notifier, nil
|
||||
}
|
||||
|
||||
func (s *NotifierService) GetNotifiers(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
) ([]*Notifier, error) {
|
||||
return s.notifierRepository.FindByUserID(user.ID)
|
||||
canView, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canView {
|
||||
return nil, errors.New("insufficient permissions to view notifiers in this workspace")
|
||||
}
|
||||
|
||||
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, notifier := range notifiers {
|
||||
notifier.HideSensitiveData()
|
||||
}
|
||||
|
||||
return notifiers, nil
|
||||
}
|
||||
|
||||
func (s *NotifierService) SendTestNotification(
|
||||
@@ -87,11 +177,15 @@ func (s *NotifierService) SendTestNotification(
|
||||
return err
|
||||
}
|
||||
|
||||
if notifier.UserID != user.ID {
|
||||
return errors.New("you have not access to this notifier")
|
||||
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canView {
|
||||
return errors.New("insufficient permissions to test notifier in this workspace")
|
||||
}
|
||||
|
||||
err = notifier.Send(s.logger, "Test message", "This is a test message")
|
||||
err = notifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -107,7 +201,38 @@ func (s *NotifierService) SendTestNotification(
|
||||
func (s *NotifierService) SendTestNotificationToNotifier(
|
||||
notifier *Notifier,
|
||||
) error {
|
||||
return notifier.Send(s.logger, "Test message", "This is a test message")
|
||||
var usingNotifier *Notifier
|
||||
|
||||
if notifier.ID != uuid.Nil {
|
||||
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existingNotifier.WorkspaceID != notifier.WorkspaceID {
|
||||
return errors.New("notifier does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingNotifier.Update(notifier)
|
||||
|
||||
if err := existingNotifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingNotifier.Validate(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingNotifier = existingNotifier
|
||||
} else {
|
||||
if err := notifier.EncryptSensitiveData(s.fieldEncryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingNotifier = notifier
|
||||
}
|
||||
|
||||
return usingNotifier.Send(s.fieldEncryptor, s.logger, "Test message", "This is a test message")
|
||||
}
|
||||
|
||||
func (s *NotifierService) SendNotification(
|
||||
@@ -126,7 +251,7 @@ func (s *NotifierService) SendNotification(
|
||||
return
|
||||
}
|
||||
|
||||
err = notifiedFromDb.Send(s.logger, title, message)
|
||||
err = notifiedFromDb.Send(s.fieldEncryptor, s.logger, title, message)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
notifiedFromDb.LastSendError = &errMsg
|
||||
@@ -143,3 +268,18 @@ func (s *NotifierService) SendNotification(
|
||||
s.logger.Error("Failed to save notifier", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *NotifierService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get notifiers for workspace deletion: %w", err)
|
||||
}
|
||||
|
||||
for _, notifier := range notifiers {
|
||||
if err := s.notifierRepository.Delete(notifier); err != nil {
|
||||
return fmt.Errorf("failed to delete notifier %s: %w", notifier.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func CreateTestNotifier(userID uuid.UUID) *Notifier {
|
||||
func CreateTestNotifier(workspaceID uuid.UUID) *Notifier {
|
||||
notifier := &Notifier{
|
||||
UserID: userID,
|
||||
WorkspaceID: workspaceID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
NotifierType: NotifierTypeWebhook,
|
||||
WebhookNotifier: &webhook_notifier.WebhookNotifier{
|
||||
|
||||
@@ -2,7 +2,7 @@ package restores
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type RestoreController struct {
|
||||
restoreService *RestoreService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -29,24 +28,18 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 401
|
||||
// @Router /restores/{backupId} [get]
|
||||
func (c *RestoreController) GetRestores(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(ctx.Param("backupId"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
restores, err := c.restoreService.GetRestores(user, backupID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -66,6 +59,12 @@ func (c *RestoreController) GetRestores(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /restores/{backupId}/restore [post]
|
||||
func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
backupID, err := uuid.Parse(ctx.Param("backupId"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
@@ -78,18 +77,6 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.restoreService.RestoreBackupWithAuth(user, backupID, requestDTO); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
|
||||
348
backend/internal/features/restores/controller_test.go
Normal file
348
backend/internal/features/restores/controller_test.go
Normal file
@@ -0,0 +1,348 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
users_dto "postgresus-backend/internal/features/users/dto"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_models "postgresus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
backups.GetBackupController(),
|
||||
GetRestoreController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
var restores []*models.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&restores,
|
||||
)
|
||||
|
||||
assert.NotNil(t, restores)
|
||||
assert.Equal(t, 0, len(restores))
|
||||
assert.NotNil(t, database)
|
||||
}
|
||||
|
||||
func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
|
||||
"Bearer "+nonMember.Token,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
var restores []*models.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
|
||||
"Bearer "+admin.Token,
|
||||
http.StatusOK,
|
||||
&restores,
|
||||
)
|
||||
|
||||
assert.NotNil(t, restores)
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+nonMember.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
},
|
||||
}
|
||||
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Database restored from backup") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for restore not found")
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackupForRestore(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *backups.Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
|
||||
func createTestDatabase(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
storage := &storages.Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: storages.StorageTypeLocal,
|
||||
Name: "Test Storage " + uuid.New().String(),
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
repo := &storages.StorageRepository{}
|
||||
storage, err := repo.Save(storage)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *backups.Backup {
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &backups.Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: backups.BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &backups.BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
@@ -1,12 +1,13 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/restores/usecases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -19,10 +20,11 @@ var restoreService = &RestoreService{
|
||||
usecases.GetRestoreBackupUsecase(),
|
||||
databases.GetDatabaseService(),
|
||||
logger.GetLogger(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
}
|
||||
var restoreController = &RestoreController{
|
||||
restoreService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
var restoreBackgroundService = &RestoreBackgroundService{
|
||||
|
||||
@@ -62,8 +62,6 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Backup.Storage").
|
||||
Preload("Backup.Database").
|
||||
Preload("Backup").
|
||||
Preload("Postgresql").
|
||||
Where("status = ?", status).
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"postgresus-backend/internal/features/restores/usecases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
"time"
|
||||
|
||||
@@ -26,6 +28,8 @@ type RestoreService struct {
|
||||
restoreBackupUsecase *usecases.RestoreBackupUsecase
|
||||
databaseService *databases.DatabaseService
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
|
||||
@@ -58,8 +62,24 @@ func (s *RestoreService) GetRestores(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this backup")
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot get restores for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*database.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access restores for this backup")
|
||||
}
|
||||
|
||||
return s.restoreRepository.FindByBackupID(backupID)
|
||||
@@ -75,8 +95,24 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this backup")
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot restore backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*database.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canAccess {
|
||||
return errors.New("insufficient permissions to restore this backup")
|
||||
}
|
||||
|
||||
backupDatabase, err := s.databaseService.GetDatabase(user, backup.DatabaseID)
|
||||
@@ -105,6 +141,16 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
}
|
||||
}()
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Database restored from backup %s for database: %s",
|
||||
backupID.String(),
|
||||
database.Name,
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -116,7 +162,12 @@ func (s *RestoreService) RestoreBackup(
|
||||
return errors.New("backup is not completed")
|
||||
}
|
||||
|
||||
if backup.Database.Type == databases.DatabaseTypePostgres {
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
if requestDTO.PostgresqlDatabase == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
@@ -157,7 +208,7 @@ func (s *RestoreService) RestoreBackup(
|
||||
}
|
||||
|
||||
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
|
||||
backup.Database.ID,
|
||||
database.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -168,6 +219,7 @@ func (s *RestoreService) RestoreBackup(
|
||||
err = s.restoreBackupUsecase.Execute(
|
||||
backupConfig,
|
||||
restore,
|
||||
database,
|
||||
backup,
|
||||
storage,
|
||||
)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import (
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restorePostgresqlBackupUsecase = &RestorePostgresqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
users_repositories.GetSecretKeyRepository(),
|
||||
}
|
||||
|
||||
func GetRestorePostgresqlBackupUsecase() *RestorePostgresqlBackupUsecase {
|
||||
|
||||
@@ -2,6 +2,7 @@ package usecases_postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -15,11 +16,14 @@ import (
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_repositories "postgresus-backend/internal/features/users/repositories"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
@@ -27,16 +31,18 @@ import (
|
||||
)
|
||||
|
||||
type RestorePostgresqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
logger *slog.Logger
|
||||
secretKeyRepo *users_repositories.SecretKeyRepository
|
||||
}
|
||||
|
||||
func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
database *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if backup.Database.Type != databases.DatabaseTypePostgres {
|
||||
if database.Type != databases.DatabaseTypePostgres {
|
||||
return errors.New("database type not supported")
|
||||
}
|
||||
|
||||
@@ -76,6 +82,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
database,
|
||||
tools.GetPostgresqlExecutable(
|
||||
pg.Version,
|
||||
"pg_restore",
|
||||
@@ -92,6 +99,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// restoreFromStorage restores backup data from storage using pg_restore
|
||||
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
database *databases.Database,
|
||||
pgBin string,
|
||||
args []string,
|
||||
password string,
|
||||
@@ -164,7 +172,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
// Add the temporary backup file as the last argument to pg_restore
|
||||
args = append(args, tempBackupFile)
|
||||
|
||||
return uc.executePgRestore(ctx, pgBin, args, pgpassFile, pgConfig, backup)
|
||||
return uc.executePgRestore(ctx, database, pgBin, args, pgpassFile, pgConfig)
|
||||
}
|
||||
|
||||
// downloadBackupToTempFile downloads backup data from storage to a temporary file
|
||||
@@ -199,18 +207,67 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
backup.ID,
|
||||
"tempFile",
|
||||
tempBackupFile,
|
||||
"encrypted",
|
||||
backup.Encryption == backups_config.BackupEncryptionEncrypted,
|
||||
)
|
||||
backupReader, err := storage.GetFile(backup.ID)
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := backupReader.Close(); err != nil {
|
||||
if err := rawReader.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close backup reader", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create a reader that handles decryption if needed
|
||||
var backupReader io.Reader = rawReader
|
||||
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
|
||||
// Validate encryption metadata
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
}
|
||||
|
||||
// Get master key
|
||||
masterKey, err := uc.secretKeyRepo.GetSecretKey()
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get master key for decryption: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and IV from base64
|
||||
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to decode encryption salt: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to decode encryption IV: %w", err)
|
||||
}
|
||||
|
||||
// Create decryption reader
|
||||
decryptReader, err := encryption.NewDecryptionReader(
|
||||
rawReader,
|
||||
masterKey,
|
||||
backup.ID,
|
||||
salt,
|
||||
iv,
|
||||
)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to create decryption reader: %w", err)
|
||||
}
|
||||
|
||||
backupReader = decryptReader
|
||||
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
|
||||
}
|
||||
|
||||
// Create temporary backup file
|
||||
tempFile, err := os.Create(tempBackupFile)
|
||||
if err != nil {
|
||||
@@ -240,11 +297,11 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
|
||||
// executePgRestore executes the pg_restore command with proper environment setup
|
||||
func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
pgBin string,
|
||||
args []string,
|
||||
pgpassFile string,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
backup *backups.Backup,
|
||||
) error {
|
||||
cmd := exec.CommandContext(ctx, pgBin, args...)
|
||||
uc.logger.Info("Executing PostgreSQL restore command", "command", cmd.String())
|
||||
@@ -293,7 +350,7 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
return uc.handlePgRestoreError(waitErr, stderrOutput, pgBin, args, backup, pgConfig)
|
||||
return uc.handlePgRestoreError(database, waitErr, stderrOutput, pgBin, args, pgConfig)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -341,11 +398,11 @@ func (uc *RestorePostgresqlBackupUsecase) setupPgRestoreEnvironment(
|
||||
|
||||
// handlePgRestoreError processes and formats pg_restore errors
|
||||
func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
|
||||
database *databases.Database,
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
pgBin string,
|
||||
args []string,
|
||||
backup *backups.Backup,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
) error {
|
||||
// Enhanced error handling for PostgreSQL connection and restore issues
|
||||
@@ -416,8 +473,8 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
|
||||
backupDbName := "unknown"
|
||||
if backup.Database != nil && backup.Database.Postgresql != nil && backup.Database.Postgresql.Database != nil {
|
||||
backupDbName = *backup.Database.Postgresql.Database
|
||||
if database.Postgresql != nil && database.Postgresql.Database != nil {
|
||||
backupDbName = *database.Postgresql.Database
|
||||
}
|
||||
|
||||
targetDbName := "unknown"
|
||||
|
||||
@@ -17,11 +17,13 @@ type RestoreBackupUsecase struct {
|
||||
func (uc *RestoreBackupUsecase) Execute(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
database *databases.Database,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if restore.Backup.Database.Type == databases.DatabaseTypePostgres {
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
return uc.restorePostgresqlBackupUsecase.Execute(
|
||||
database,
|
||||
backupConfig,
|
||||
restore,
|
||||
backup,
|
||||
|
||||
@@ -2,15 +2,16 @@ package storages
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type StorageController struct {
|
||||
storageService *StorageService
|
||||
userService *users.UserService
|
||||
storageService *StorageService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -29,35 +30,40 @@ func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param storage body Storage true "Storage data"
|
||||
// @Param request body Storage true "Storage data with workspaceId"
|
||||
// @Success 200 {object} Storage
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages [post]
|
||||
func (c *StorageController) SaveStorage(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var storage Storage
|
||||
if err := ctx.ShouldBindJSON(&storage); err != nil {
|
||||
var request Storage
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := storage.Validate(); err != nil {
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage storage in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.storageService.SaveStorage(user, &storage); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, storage)
|
||||
ctx.JSON(http.StatusOK, request)
|
||||
}
|
||||
|
||||
// GetStorage
|
||||
@@ -70,11 +76,12 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
|
||||
// @Success 200 {object} Storage
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages/{id} [get]
|
||||
func (c *StorageController) GetStorage(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,6 +93,10 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
|
||||
|
||||
storage, err := c.storageService.GetStorage(user, id)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view storage in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -95,22 +106,41 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
|
||||
|
||||
// GetStorages
|
||||
// @Summary Get all storages
|
||||
// @Description Get all storages for the current user
|
||||
// @Description Get all storages for a workspace
|
||||
// @Tags storages
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param workspace_id query string true "Workspace ID"
|
||||
// @Success 200 {array} Storage
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages [get]
|
||||
func (c *StorageController) GetStorages(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
storages, err := c.storageService.GetStorages(user)
|
||||
workspaceIDStr := ctx.Query("workspace_id")
|
||||
if workspaceIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
workspaceID, err := uuid.Parse(workspaceIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
|
||||
return
|
||||
}
|
||||
|
||||
storages, err := c.storageService.GetStorages(user, workspaceID)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view storages in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -128,11 +158,12 @@ func (c *StorageController) GetStorages(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages/{id} [delete]
|
||||
func (c *StorageController) DeleteStorage(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -143,6 +174,10 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.storageService.DeleteStorage(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage storage in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -160,11 +195,12 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages/{id}/test [post]
|
||||
func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -175,6 +211,10 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.storageService.TestStorageConnection(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to test storage in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -189,33 +229,44 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param storage body Storage true "Storage data"
|
||||
// @Param request body Storage true "Storage data with workspaceId"
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /storages/direct-test [post]
|
||||
func (c *StorageController) TestStorageConnectionDirect(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Storage
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
var storage Storage
|
||||
if err := ctx.ShouldBindJSON(&storage); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// For direct test, associate with the current user
|
||||
storage.UserID = user.ID
|
||||
|
||||
if err := storage.Validate(); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
if !canView {
|
||||
ctx.JSON(
|
||||
http.StatusForbidden,
|
||||
gin.H{"error": "insufficient permissions to test storage in this workspace"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.storageService.TestStorageConnectionDirect(&storage); err != nil {
|
||||
if err := c.storageService.TestStorageConnectionDirect(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,25 +1,46 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
"postgresus-backend/internal/features/users"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
|
||||
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(user.UserID)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
verifyStorageData(t, storage, &savedStorage)
|
||||
@@ -30,8 +51,8 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages/"+savedStorage.ID.String(),
|
||||
user.Token,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&retrievedStorage,
|
||||
)
|
||||
@@ -41,181 +62,788 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
|
||||
// Verify storage is returned via GET all storages
|
||||
var storages []Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&storages,
|
||||
)
|
||||
|
||||
assert.Contains(t, storages, savedStorage)
|
||||
|
||||
RemoveTestStorage(savedStorage.ID)
|
||||
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_UpdateExistingStorage_UpdatedStorageReturnedViaGet(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(user.UserID)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
// Save initial storage
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
// Modify storage name
|
||||
updatedName := "Updated Storage " + uuid.New().String()
|
||||
savedStorage.Name = updatedName
|
||||
|
||||
// Update storage
|
||||
var updatedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, savedStorage, http.StatusOK, &updatedStorage,
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
savedStorage,
|
||||
http.StatusOK,
|
||||
&updatedStorage,
|
||||
)
|
||||
|
||||
// Verify updated data
|
||||
assert.Equal(t, updatedName, updatedStorage.Name)
|
||||
assert.Equal(t, savedStorage.ID, updatedStorage.ID)
|
||||
|
||||
// Verify through GET
|
||||
var retrievedStorage Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages/"+updatedStorage.ID.String(),
|
||||
user.Token,
|
||||
http.StatusOK,
|
||||
&retrievedStorage,
|
||||
)
|
||||
|
||||
verifyStorageData(t, &updatedStorage, &retrievedStorage)
|
||||
|
||||
// Verify storage is returned via GET all storages
|
||||
var storages []Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
|
||||
)
|
||||
|
||||
assert.Contains(t, storages, updatedStorage)
|
||||
|
||||
RemoveTestStorage(updatedStorage.ID)
|
||||
deleteStorage(t, router, updatedStorage.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_DeleteStorage_StorageNotReturnedViaGet(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(user.UserID)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
// Save initial storage
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
// Delete storage
|
||||
test_utils.MakeDeleteRequest(
|
||||
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusOK,
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
// Try to get deleted storage, should return error
|
||||
response := test_utils.MakeGetRequest(
|
||||
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusBadRequest,
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(response.Body), "error")
|
||||
|
||||
// Verify storage is not returned via GET all storages
|
||||
var storages []Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
|
||||
)
|
||||
|
||||
assert.NotContains(t, storages, savedStorage)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_TestDirectStorageConnection_ConnectionEstablished(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(user.UserID)
|
||||
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
response := test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/storages/direct-test", user.Token, storage, http.StatusOK,
|
||||
t, router, "/api/v1/storages/direct-test", "Bearer "+owner.Token, *storage, http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(response.Body), "successful")
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_TestExistingStorageConnection_ConnectionEstablished(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(user.UserID)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
// Test connection to existing storage
|
||||
response := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages/"+savedStorage.ID.String()+"/test",
|
||||
user.Token,
|
||||
fmt.Sprintf("/api/v1/storages/%s/test", savedStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(response.Body), "successful")
|
||||
|
||||
RemoveTestStorage(savedStorage.ID)
|
||||
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_CallAllMethodsWithoutAuth_UnauthorizedErrorReturned(t *testing.T) {
|
||||
func Test_ViewerCanViewStorages_ButCannotModify(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
storage := createNewStorage(uuid.New())
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
viewer,
|
||||
users_enums.WorkspaceRoleViewer,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
// Test endpoints without auth
|
||||
endpoints := []struct {
|
||||
method string
|
||||
url string
|
||||
body interface{}
|
||||
}{
|
||||
{"GET", "/api/v1/storages", nil},
|
||||
{"GET", "/api/v1/storages/" + uuid.New().String(), nil},
|
||||
{"POST", "/api/v1/storages", storage},
|
||||
{"DELETE", "/api/v1/storages/" + uuid.New().String(), nil},
|
||||
{"POST", "/api/v1/storages/" + uuid.New().String() + "/test", nil},
|
||||
{"POST", "/api/v1/storages/direct-test", storage},
|
||||
}
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
for _, endpoint := range endpoints {
|
||||
testUnauthorizedEndpoint(t, router, endpoint.method, endpoint.url, endpoint.body)
|
||||
}
|
||||
// Viewer can GET storages
|
||||
var storages []Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
|
||||
"Bearer "+viewer.Token,
|
||||
http.StatusOK,
|
||||
&storages,
|
||||
)
|
||||
assert.Len(t, storages, 1)
|
||||
|
||||
// Viewer cannot CREATE storage
|
||||
newStorage := createNewStorage(workspace.ID)
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/storages", "Bearer "+viewer.Token, *newStorage, http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Viewer cannot UPDATE storage
|
||||
savedStorage.Name = "Updated by viewer"
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/storages", "Bearer "+viewer.Token, savedStorage, http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Viewer cannot DELETE storage
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+viewer.Token,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func testUnauthorizedEndpoint(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
method, url string,
|
||||
body interface{},
|
||||
) {
|
||||
test_utils.MakeRequest(t, router, test_utils.RequestOptions{
|
||||
Method: method,
|
||||
URL: url,
|
||||
Body: body,
|
||||
ExpectedStatus: http.StatusUnauthorized,
|
||||
})
|
||||
func Test_MemberCanManageStorages(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
member,
|
||||
users_enums.WorkspaceRoleMember,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
// Member can CREATE storage
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+member.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
assert.NotEmpty(t, savedStorage.ID)
|
||||
|
||||
// Member can UPDATE storage
|
||||
savedStorage.Name = "Updated by member"
|
||||
var updatedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+member.Token,
|
||||
savedStorage,
|
||||
http.StatusOK,
|
||||
&updatedStorage,
|
||||
)
|
||||
assert.Equal(t, "Updated by member", updatedStorage.Name)
|
||||
|
||||
// Member can DELETE storage
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+member.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_AdminCanManageStorages(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
workspaces_testing.AddMemberToWorkspace(
|
||||
workspace,
|
||||
admin,
|
||||
users_enums.WorkspaceRoleAdmin,
|
||||
owner.Token,
|
||||
router,
|
||||
)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
// Admin can CREATE, UPDATE, DELETE
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+admin.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
savedStorage.Name = "Updated by admin"
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/storages", "Bearer "+admin.Token, savedStorage, http.StatusOK,
|
||||
)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+admin.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
storage := createNewStorage(workspace.ID)
|
||||
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*storage,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
// Outsider cannot GET storages
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
|
||||
"Bearer "+outsider.Token,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Outsider cannot CREATE storage
|
||||
test_utils.MakePostRequest(
|
||||
t, router, "/api/v1/storages", "Bearer "+outsider.Token, *storage, http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Outsider cannot UPDATE storage
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+outsider.Token,
|
||||
savedStorage,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
// Outsider cannot DELETE storage
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+outsider.Token,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
|
||||
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *testing.T) {
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
|
||||
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
|
||||
storage1 := createNewStorage(workspace1.ID)
|
||||
|
||||
var savedStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner1.Token,
|
||||
*storage1,
|
||||
http.StatusOK,
|
||||
&savedStorage,
|
||||
)
|
||||
|
||||
// Try to access workspace1's storage with owner2 from workspace2
|
||||
response := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
|
||||
"Bearer "+owner2.Token,
|
||||
http.StatusForbidden,
|
||||
)
|
||||
assert.Contains(t, string(response.Body), "insufficient permissions")
|
||||
|
||||
deleteStorage(t, router, savedStorage.ID, workspace1.ID, owner1.Token)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace1, router)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace2, router)
|
||||
}
|
||||
|
||||
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
storageType StorageType
|
||||
createStorage func(workspaceID uuid.UUID) *Storage
|
||||
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
|
||||
verifySensitiveData func(t *testing.T, storage *Storage)
|
||||
verifyHiddenData func(t *testing.T, storage *Storage)
|
||||
}{
|
||||
{
|
||||
name: "S3 Storage",
|
||||
storageType: StorageTypeS3,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Test S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "test-bucket",
|
||||
S3Region: "us-east-1",
|
||||
S3AccessKey: "original-access-key",
|
||||
S3SecretKey: "original-secret-key",
|
||||
S3Endpoint: "https://s3.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeS3,
|
||||
Name: "Updated S3 Storage",
|
||||
S3Storage: &s3_storage.S3Storage{
|
||||
S3Bucket: "updated-bucket",
|
||||
S3Region: "us-west-2",
|
||||
S3AccessKey: "",
|
||||
S3SecretKey: "",
|
||||
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.S3Storage.S3AccessKey, "enc:"),
|
||||
"S3AccessKey should be encrypted with 'enc:' prefix")
|
||||
assert.True(t, strings.HasPrefix(storage.S3Storage.S3SecretKey, "enc:"),
|
||||
"S3SecretKey should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
accessKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3AccessKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-access-key", accessKey)
|
||||
|
||||
secretKey, err := encryptor.Decrypt(storage.ID, storage.S3Storage.S3SecretKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-secret-key", secretKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
|
||||
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Local Storage",
|
||||
storageType: StorageTypeLocal,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Test Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Updated Local Storage",
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "NAS Storage",
|
||||
storageType: StorageTypeNAS,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeNAS,
|
||||
Name: "Test NAS Storage",
|
||||
NASStorage: &nas_storage.NASStorage{
|
||||
Host: "nas.example.com",
|
||||
Port: 445,
|
||||
Share: "backups",
|
||||
Username: "testuser",
|
||||
Password: "original-password",
|
||||
UseSSL: false,
|
||||
Domain: "WORKGROUP",
|
||||
Path: "/test",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeNAS,
|
||||
Name: "Updated NAS Storage",
|
||||
NASStorage: &nas_storage.NASStorage{
|
||||
Host: "nas2.example.com",
|
||||
Port: 445,
|
||||
Share: "backups2",
|
||||
Username: "testuser2",
|
||||
Password: "",
|
||||
UseSSL: true,
|
||||
Domain: "WORKGROUP2",
|
||||
Path: "/test2",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.NASStorage.Password, "enc:"),
|
||||
"Password should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
password, err := encryptor.Decrypt(storage.ID, storage.NASStorage.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password", password)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.NASStorage.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure Blob Storage (Connection String)",
|
||||
storageType: StorageTypeAzureBlob,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Test Azure Blob Storage",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
|
||||
ConnectionString: "original-connection-string",
|
||||
ContainerName: "test-container",
|
||||
Endpoint: "",
|
||||
Prefix: "backups/",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Updated Azure Blob Storage",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodConnectionString,
|
||||
ConnectionString: "",
|
||||
ContainerName: "updated-container",
|
||||
Endpoint: "https://custom.blob.core.windows.net",
|
||||
Prefix: "backups2/",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.ConnectionString, "enc:"),
|
||||
"ConnectionString should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
connectionString, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.AzureBlobStorage.ConnectionString,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-connection-string", connectionString)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Azure Blob Storage (Account Key)",
|
||||
storageType: StorageTypeAzureBlob,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Test Azure Blob with Account Key",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
|
||||
AccountName: "testaccount",
|
||||
AccountKey: "original-account-key",
|
||||
ContainerName: "test-container",
|
||||
Endpoint: "",
|
||||
Prefix: "backups/",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeAzureBlob,
|
||||
Name: "Updated Azure Blob with Account Key",
|
||||
AzureBlobStorage: &azure_blob_storage.AzureBlobStorage{
|
||||
AuthMethod: azure_blob_storage.AuthMethodAccountKey,
|
||||
AccountName: "updatedaccount",
|
||||
AccountKey: "",
|
||||
ContainerName: "updated-container",
|
||||
Endpoint: "https://custom.blob.core.windows.net",
|
||||
Prefix: "backups2/",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.AzureBlobStorage.AccountKey, "enc:"),
|
||||
"AccountKey should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
accountKey, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.AzureBlobStorage.AccountKey,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-account-key", accountKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.ConnectionString)
|
||||
assert.Equal(t, "", storage.AzureBlobStorage.AccountKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Google Drive Storage",
|
||||
storageType: StorageTypeGoogleDrive,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeGoogleDrive,
|
||||
Name: "Test Google Drive Storage",
|
||||
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
|
||||
ClientID: "original-client-id",
|
||||
ClientSecret: "original-client-secret",
|
||||
TokenJSON: `{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeGoogleDrive,
|
||||
Name: "Updated Google Drive Storage",
|
||||
GoogleDriveStorage: &google_drive_storage.GoogleDriveStorage{
|
||||
ClientID: "updated-client-id",
|
||||
ClientSecret: "",
|
||||
TokenJSON: "",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.ClientSecret, "enc:"),
|
||||
"ClientSecret should be encrypted with 'enc:' prefix")
|
||||
assert.True(t, strings.HasPrefix(storage.GoogleDriveStorage.TokenJSON, "enc:"),
|
||||
"TokenJSON should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
clientSecret, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.GoogleDriveStorage.ClientSecret,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-client-secret", clientSecret)
|
||||
|
||||
tokenJSON, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.GoogleDriveStorage.TokenJSON,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
`{"access_token":"ya29.test-access-token","token_type":"Bearer","expiry":"2030-12-31T23:59:59Z","refresh_token":"1//test-refresh-token"}`,
|
||||
tokenJSON,
|
||||
)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.GoogleDriveStorage.ClientSecret)
|
||||
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
// Phase 1: Create storage with sensitive data
|
||||
initialStorage := tc.createStorage(workspace.ID)
|
||||
var createdStorage Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*initialStorage,
|
||||
http.StatusOK,
|
||||
&createdStorage,
|
||||
)
|
||||
|
||||
assert.NotEmpty(t, createdStorage.ID)
|
||||
assert.Equal(t, initialStorage.Name, createdStorage.Name)
|
||||
|
||||
// Phase 2: Verify sensitive data is encrypted in repository after creation
|
||||
repository := &StorageRepository{}
|
||||
storageFromDBAfterCreate, err := repository.FindByID(createdStorage.ID)
|
||||
assert.NoError(t, err)
|
||||
tc.verifySensitiveData(t, storageFromDBAfterCreate)
|
||||
|
||||
// Phase 3: Read via service - sensitive data should be hidden
|
||||
var retrievedStorage Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&retrievedStorage,
|
||||
)
|
||||
|
||||
tc.verifyHiddenData(t, &retrievedStorage)
|
||||
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
|
||||
|
||||
// Phase 4: Update with non-sensitive changes only (sensitive fields empty)
|
||||
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
|
||||
var updateResponse Storage
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/storages",
|
||||
"Bearer "+owner.Token,
|
||||
*updatedStorage,
|
||||
http.StatusOK,
|
||||
&updateResponse,
|
||||
)
|
||||
|
||||
// Verify non-sensitive fields were updated
|
||||
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
|
||||
|
||||
// Phase 5: Retrieve directly from repository to verify sensitive data preservation
|
||||
storageFromDB, err := repository.FindByID(createdStorage.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify original sensitive data is still present in DB
|
||||
tc.verifySensitiveData(t, storageFromDB)
|
||||
|
||||
// Verify non-sensitive fields were updated in DB
|
||||
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
|
||||
|
||||
// Additional verification: Check via GET that data is still hidden
|
||||
var finalRetrieved Storage
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&finalRetrieved,
|
||||
)
|
||||
tc.verifyHiddenData(t, &finalRetrieved)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
controller := GetStorageController()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
controller.RegisterRoutes(v1)
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
GetStorageController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
|
||||
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
|
||||
}
|
||||
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func createNewStorage(userID uuid.UUID) *Storage {
|
||||
func createNewStorage(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
UserID: userID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeLocal,
|
||||
Name: "Test Storage " + uuid.New().String(),
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
@@ -225,5 +853,20 @@ func createNewStorage(userID uuid.UUID) *Storage {
|
||||
func verifyStorageData(t *testing.T, expected *Storage, actual *Storage) {
|
||||
assert.Equal(t, expected.Name, actual.Name)
|
||||
assert.Equal(t, expected.Type, actual.Type)
|
||||
assert.Equal(t, expected.UserID, actual.UserID)
|
||||
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
|
||||
}
|
||||
|
||||
func deleteStorage(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
storageID, workspaceID uuid.UUID,
|
||||
token string,
|
||||
) {
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/storages/%s", storageID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/users"
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
var storageRepository = &StorageRepository{}
|
||||
var storageService = &StorageService{
|
||||
storageRepository,
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var storageController = &StorageController{
|
||||
storageService,
|
||||
users.GetUserService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
func GetStorageService() *StorageService {
|
||||
@@ -20,3 +25,7 @@ func GetStorageService() *StorageService {
|
||||
func GetStorageController() *StorageController {
|
||||
return storageController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ const (
|
||||
StorageTypeS3 StorageType = "S3"
|
||||
StorageTypeGoogleDrive StorageType = "GOOGLE_DRIVE"
|
||||
StorageTypeNAS StorageType = "NAS"
|
||||
StorageTypeAzureBlob StorageType = "AZURE_BLOB"
|
||||
)
|
||||
|
||||
@@ -3,18 +3,28 @@ package storages
|
||||
import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type StorageFileSaver interface {
|
||||
SaveFile(logger *slog.Logger, fileID uuid.UUID, file io.Reader) error
|
||||
SaveFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error
|
||||
|
||||
GetFile(fileID uuid.UUID) (io.ReadCloser, error)
|
||||
GetFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) (io.ReadCloser, error)
|
||||
|
||||
DeleteFile(fileID uuid.UUID) error
|
||||
DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error
|
||||
|
||||
Validate() error
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
TestConnection() error
|
||||
TestConnection(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
HideSensitiveData()
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user