Compare commits

...

21 Commits

Author SHA1 Message Date
Rostislav Dugin
069d6bc8fe FEATURE (logo): Update logo 2025-11-14 20:19:26 +03:00
Rostislav Dugin
242d5543d4 FIX (backups): Avoid possibility of breaking DB on backup fail 2025-11-14 19:56:56 +03:00
Rostislav Dugin
02c735bc5a FIX (notifiers): Improve email validation 2025-11-14 18:02:27 +03:00
Rostislav Dugin
793b575146 FIX (storages): Ignore files removal errors for unavailable storage when deleting the database 2025-11-14 18:02:13 +03:00
Rostislav Dugin
a6e84b45f2 Merge pull request #84 from RostislavDugin/feature/add_pg_12
Feature/add pg 12
2025-11-12 15:43:09 +03:00
Rostislav Dugin
a941fbd093 FEATURE (postgres): Add PostgreSQL 12 tests and CI \ CD config 2025-11-12 15:39:44 +03:00
Rostislav Dugin
4492ba41f5 Merge pull request #82 from romanesko/feature/v12-support
feat: add PostgreSQL 12 support
2025-11-12 15:04:12 +03:00
Roman Bykovsky
3a5ac4b479 feat: add PostgreSQL 12 support 2025-11-11 18:53:26 +03:00
Rostislav Dugin
77aaabeaa1 FEATURE (docs): Update readme and docs links 2025-11-11 16:56:33 +03:00
Rostislav Dugin
01911dbf72 FIX (notifiers & storages): Avoid request for workspace_id for storages and notifiers removal 2025-11-11 10:05:45 +03:00
Rostislav Dugin
1a16f27a5d FIX (notifiers): Fix update of existing DB notifiers 2025-11-11 08:10:02 +03:00
Rostislav Dugin
778db71625 FIX (tests): Improve tests stability in CI \ CD 2025-11-09 20:41:36 +03:00
Rostislav Dugin
45fc9a7fff FIX (databases): Verify DB nil on side of DB instead of interface 2025-11-09 20:03:22 +03:00
Rostislav Dugin
7f5e786261 FIX (databases): If some DB missing PostgreSQL db fix nil issue 2025-11-09 18:57:42 +03:00
Rostislav Dugin
9b066bcb8a FEATURE (email): Add "from" field 2025-11-08 20:47:35 +03:00
Rostislav Dugin
9ea795b48f FEATURE (backups): Add backups cancelling 2025-11-08 20:04:06 +03:00
Rostislav Dugin
a809dc8a9c FEATURE (protection): Do not expose sensetive data of databases, notifiers and storages from API + make backups lazy loaded 2025-11-08 18:49:23 +03:00
Rostislav Dugin
bd053b51a3 FIX (workspaces): Fix switch between workspaces 2025-11-07 15:54:15 +03:00
Rostislav Dugin
431e9861f4 FEATURE (workspaces): Add workspaces with users management and global Postgresus settings 2025-11-07 15:28:03 +03:00
Rostislav Dugin
de1fd4c4da Merge pull request #56 from SebasGDEV/fix/readmetypo
fix: typo redirecting to contribute/README.me
2025-11-02 11:53:03 +03:00
Sebastian G
df55fd17d5 fix: typo redirecting to contribute/README.me 2025-11-01 11:27:58 -07:00
287 changed files with 22190 additions and 2721 deletions

View File

@@ -127,6 +127,7 @@ jobs:
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
TEST_GOOGLE_DRIVE_TOKEN_JSON=${{ secrets.TEST_GOOGLE_DRIVE_TOKEN_JSON }}
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -138,6 +139,9 @@ jobs:
TEST_MINIO_CONSOLE_PORT=9001
# testing NAS
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
EOF
- name: Start test containers
@@ -151,6 +155,7 @@ jobs:
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
# Wait for test databases
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
@@ -182,7 +187,7 @@ jobs:
- name: Run Go tests
run: |
cd backend
go test ./internal/...
go test -p=1 -count=1 -failfast ./internal/...
- name: Stop test containers
if: always()

View File

@@ -77,7 +77,7 @@ ENV APP_VERSION=$APP_VERSION
# Set production mode for Docker containers
ENV ENV_MODE=production
# Install PostgreSQL server and client tools (versions 13-17)
# Install PostgreSQL server and client tools (versions 12-18)
RUN apt-get update && apt-get install -y --no-install-recommends \
wget ca-certificates gnupg lsb-release sudo gosu && \
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
@@ -85,7 +85,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
> /etc/apt/sources.list.d/pgdg.list && \
apt-get update && \
apt-get install -y --no-install-recommends \
postgresql-17 postgresql-18 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
rm -rf /var/lib/apt/lists/*

View File

@@ -187,7 +187,7 @@
same "license" line as the copyright notice for easier
identification within third-party archives.
Copyright 2025 LogBull
Copyright 2025 Postgresus
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -9,7 +9,7 @@
[![Docker Pulls](https://img.shields.io/docker/pulls/rostislavdugin/postgresus?color=brightgreen)](https://hub.docker.com/r/rostislavdugin/postgresus)
[![Platform](https://img.shields.io/badge/platform-linux%20%7C%20macos%20%7C%20windows-lightgrey)](https://github.com/RostislavDugin/postgresus)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-13%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![PostgreSQL](https://img.shields.io/badge/PostgreSQL-12%20%7C%2013%20%7C%2014%20%7C%2015%20%7C%2016%20%7C%2017%20%7C%2018-336791?logo=postgresql&logoColor=white)](https://www.postgresql.org/)
[![Self Hosted](https://img.shields.io/badge/self--hosted-yes-brightgreen)](https://github.com/RostislavDugin/postgresus)
[![Open Source](https://img.shields.io/badge/open%20source-❤️-red)](https://github.com/RostislavDugin/postgresus)
@@ -40,13 +40,13 @@
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
### 🗄️ **Multiple Storage Destinations**
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(docs)</a>
- **Local storage**: Keep backups on your VPS/server
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
- **Secure**: All data stays under your control
### 📱 **Smart Notifications**
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(docs)</a>
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
- **Real-time updates**: Success and failure notifications
@@ -54,17 +54,24 @@
### 🐘 **PostgreSQL Support**
- **Multiple versions**: PostgreSQL 13, 14, 15, 16, 17 and 18
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
- **SSL support**: Secure connections available
- **Easy restoration**: One-click restore from any backup
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
- **Access management**: Control who can view or manage specific databases with role-based permissions
- **Audit logs**: Track all system activities and changes made by users
- **User roles**: Assign viewer, member, admin or owner roles within workspaces
### 🐳 **Self-Hosted & Secure**
- **Docker-based**: Easy deployment and management
- **Privacy-first**: All your data stays on your infrastructure
- **Open source**: Apache 2.0 licensed, inspect every line of code
### 📦 Installation
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
You have three ways to install Postgresus:
@@ -118,8 +125,6 @@ This single command will:
Create a `docker-compose.yml` file with the following configuration:
```yaml
version: "3"
services:
postgresus:
container_name: postgresus
@@ -149,14 +154,16 @@ docker compose up -d
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
### 🔑 Resetting Admin Password
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
If you need to reset the admin password, you can use the built-in password reset command:
If you need to reset the password, you can use the built-in password reset command:
```bash
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123"
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123" --email="admin"
```
Replace `admin` with the actual email address of the user whose password you want to reset.
---
## 📝 License
@@ -167,4 +174,4 @@ This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENS
## 🤝 Contributing
Contributions are welcome! Read [contributing guide](contribute/readme.md) for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
Contributions are welcome! Read <a href="https://postgresus.com/contributing">contributing guide</a> for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 791 KiB

After

Width:  |  Height:  |  Size: 914 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 12 KiB

After

Width:  |  Height:  |  Size: 23 KiB

View File

@@ -1,14 +1,16 @@
---
description:
globs:
description:
globs:
alwaysApply: true
---
1. When we write controller:
- we combine all routes to single controller
- names them as .WhatWeDo (not "handlers") concept
2. We use gin and *gin.Context for all routes.
Example:
2. We use gin and \*gin.Context for all routes.
Example:
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
@@ -17,24 +19,26 @@ func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
package audit_logs
import (
"net/http"
"net/http"
user_models "logbull/internal/features/users/models"
user_models "postgresus/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
auditLogService \*AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
@@ -52,29 +56,30 @@ func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
@@ -94,34 +99,35 @@ func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
user, isOk := ctx.MustGet("user").(\*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -1,16 +1,18 @@
---
alwaysApply: false
---
This is example of CRUD:
------ backend/internal/features/audit_logs/controller.go ------
``````
```
package audit_logs
import (
"net/http"
user_models "logbull/internal/features/users/models"
user_models "postgresus/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -117,9 +119,11 @@ func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
ctx.JSON(http.StatusOK, response)
}
``````
```
------ backend/internal/features/audit_logs/controller_test.go ------
``````
```
package audit_logs
import (
@@ -128,12 +132,12 @@ import (
"testing"
"time"
user_enums "logbull/internal/features/users/enums"
users_middleware "logbull/internal/features/users/middleware"
users_services "logbull/internal/features/users/services"
users_testing "logbull/internal/features/users/testing"
"logbull/internal/storage"
test_utils "logbull/internal/util/testing"
user_enums "postgresus/internal/features/users/enums"
users_middleware "postgresus/internal/features/users/middleware"
users_services "postgresus/internal/features/users/services"
users_testing "postgresus/internal/features/users/testing"
"postgresus/internal/storage"
test_utils "postgresus/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -256,14 +260,16 @@ func createRouter() *gin.Engine {
return router
}
``````
```
------ backend/internal/features/audit_logs/di.go ------
``````
```
package audit_logs
import (
users_services "logbull/internal/features/users/services"
"logbull/internal/util/logger"
users_services "postgresus/internal/features/users/services"
"postgresus/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
@@ -289,9 +295,11 @@ func SetupDependencies() {
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}
``````
```
------ backend/internal/features/audit_logs/dto.go ------
``````
```
package audit_logs
import "time"
@@ -309,9 +317,11 @@ type GetAuditLogsResponse struct {
Offset int `json:"offset"`
}
``````
```
------ backend/internal/features/audit_logs/models.go ------
``````
```
package audit_logs
import (
@@ -332,13 +342,15 @@ func (AuditLog) TableName() string {
return "audit_logs"
}
``````
```
------ backend/internal/features/audit_logs/repository.go ------
``````
```
package audit_logs
import (
"logbull/internal/storage"
"postgresus/internal/storage"
"time"
"github.com/google/uuid"
@@ -429,9 +441,11 @@ func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
return count, err
}
``````
```
------ backend/internal/features/audit_logs/service.go ------
``````
```
package audit_logs
import (
@@ -439,8 +453,8 @@ import (
"log/slog"
"time"
user_enums "logbull/internal/features/users/enums"
user_models "logbull/internal/features/users/models"
user_enums "postgresus/internal/features/users/enums"
user_models "postgresus/internal/features/users/models"
"github.com/google/uuid"
)
@@ -560,17 +574,19 @@ func (s *AuditLogService) GetProjectAuditLogs(
}, nil
}
``````
```
------ backend/internal/features/audit_logs/service_test.go ------
``````
```
package audit_logs
import (
"testing"
"time"
user_enums "logbull/internal/features/users/enums"
users_testing "logbull/internal/features/users/testing"
user_enums "postgresus/internal/features/users/enums"
users_testing "postgresus/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -652,4 +668,4 @@ func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt ti
db.Create(log)
}
``````
```

View File

@@ -17,6 +17,7 @@ TEST_GOOGLE_DRIVE_CLIENT_ID=
TEST_GOOGLE_DRIVE_CLIENT_SECRET=
TEST_GOOGLE_DRIVE_TOKEN_JSON="{\"access_token\":\"ya29..."
# testing DBs
TEST_POSTGRES_12_PORT=5000
TEST_POSTGRES_13_PORT=5001
TEST_POSTGRES_14_PORT=5002
TEST_POSTGRES_15_PORT=5003
@@ -27,4 +28,7 @@ TEST_POSTGRES_18_PORT=5006
TEST_MINIO_PORT=9000
TEST_MINIO_CONSOLE_PORT=9001
# testing NAS
TEST_NAS_PORT=7006
TEST_NAS_PORT=7006
# testing Telegram
TEST_TELEGRAM_BOT_TOKEN=
TEST_TELEGRAM_CHAT_ID=

3
backend/.gitignore vendored
View File

@@ -12,4 +12,5 @@ swagger/swagger.yaml
postgresus-backend.exe
ui/build/*
pgdata-for-restore/
temp/
temp/
cmd.exe

View File

@@ -13,7 +13,7 @@ import (
"time"
"postgresus-backend/internal/config"
"postgresus-backend/internal/downdetect"
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
@@ -24,7 +24,10 @@ import (
"postgresus-backend/internal/features/restores"
"postgresus-backend/internal/features/storages"
system_healthcheck "postgresus-backend/internal/features/system/healthcheck"
"postgresus-backend/internal/features/users"
users_controllers "postgresus-backend/internal/features/users/controllers"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
env_utils "postgresus-backend/internal/util/env"
files_utils "postgresus-backend/internal/util/files"
"postgresus-backend/internal/util/logger"
@@ -61,13 +64,14 @@ func main() {
os.Exit(1)
}
// Handle password reset if flag is provided
newPassword := flag.String("new-password", "", "Set a new password for the user")
flag.Parse()
if *newPassword != "" {
resetPassword(*newPassword, log)
err = users_services.GetUserService().CreateInitialAdmin()
if err != nil {
log.Error("Failed to create initial admin", "error", err)
os.Exit(1)
}
handlePasswordReset(log)
go generateSwaggerDocs(log)
gin.SetMode(gin.ReleaseMode)
@@ -91,11 +95,33 @@ func main() {
startServerWithGracefulShutdown(log, ginApp)
}
func resetPassword(newPassword string, log *slog.Logger) {
func handlePasswordReset(log *slog.Logger) {
audit_logs.SetupDependencies()
newPassword := flag.String("new-password", "", "Set a new password for the user")
email := flag.String("email", "", "Email of the user to reset password")
flag.Parse()
if *newPassword == "" {
return
}
log.Info("Found reset password command - reseting password...")
if *email == "" {
log.Info("No email provided, please provide an email via --email=\"some@email.com\" flag")
os.Exit(1)
}
resetPassword(*email, *newPassword, log)
}
func resetPassword(email string, newPassword string, log *slog.Logger) {
log.Info("Resetting password...")
userService := users.GetUserService()
err := userService.ChangePassword(newPassword)
userService := users_services.GetUserService()
err := userService.ChangeUserPasswordByEmail(email, newPassword)
if err != nil {
log.Error("Failed to reset password", "error", err)
os.Exit(1)
@@ -146,37 +172,44 @@ func setUpRoutes(r *gin.Engine) {
// Mount Swagger UI
v1.GET("/docs/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
downdetectContoller := downdetect.GetDowndetectController()
userController := users.GetUserController()
notifierController := notifiers.GetNotifierController()
storageController := storages.GetStorageController()
databaseController := databases.GetDatabaseController()
backupController := backups.GetBackupController()
restoreController := restores.GetRestoreController()
healthcheckController := system_healthcheck.GetHealthcheckController()
healthcheckConfigController := healthcheck_config.GetHealthcheckConfigController()
healthcheckAttemptController := healthcheck_attempt.GetHealthcheckAttemptController()
diskController := disk.GetDiskController()
backupConfigController := backups_config.GetBackupConfigController()
downdetectContoller.RegisterRoutes(v1)
// Public routes (only user auth routes and healthcheck should be public)
userController := users_controllers.GetUserController()
userController.RegisterRoutes(v1)
notifierController.RegisterRoutes(v1)
storageController.RegisterRoutes(v1)
databaseController.RegisterRoutes(v1)
backupController.RegisterRoutes(v1)
restoreController.RegisterRoutes(v1)
healthcheckController.RegisterRoutes(v1)
diskController.RegisterRoutes(v1)
healthcheckConfigController.RegisterRoutes(v1)
healthcheckAttemptController.RegisterRoutes(v1)
backupConfigController.RegisterRoutes(v1)
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
// Setup auth middleware
userService := users_services.GetUserService()
authMiddleware := users_middleware.AuthMiddleware(userService)
// Protected routes
protected := v1.Group("")
protected.Use(authMiddleware)
userController.RegisterProtectedRoutes(protected)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected)
workspaces_controllers.GetMembershipController().RegisterRoutes(protected)
disk.GetDiskController().RegisterRoutes(protected)
notifiers.GetNotifierController().RegisterRoutes(protected)
storages.GetStorageController().RegisterRoutes(protected)
databases.GetDatabaseController().RegisterRoutes(protected)
backups.GetBackupController().RegisterRoutes(protected)
restores.GetRestoreController().RegisterRoutes(protected)
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
backups_config.GetBackupConfigController().RegisterRoutes(protected)
audit_logs.GetAuditLogController().RegisterRoutes(protected)
users_controllers.GetManagementController().RegisterRoutes(protected)
users_controllers.GetSettingsController().RegisterRoutes(protected)
}
func setUpDependencies() {
databases.SetupDependencies()
backups.SetupDependencies()
restores.SetupDependencies()
healthcheck_config.SetupDependencies()
audit_logs.SetupDependencies()
notifiers.SetupDependencies()
storages.SetupDependencies()
}
func runBackgroundTasks(log *slog.Logger) {

View File

@@ -32,6 +32,17 @@ services:
command: server /data --console-address ":9001"
# Test PostgreSQL containers
test-postgres-12:
image: postgres:12
ports:
- "${TEST_POSTGRES_12_PORT}:5432"
environment:
- POSTGRES_DB=testdb
- POSTGRES_USER=testuser
- POSTGRES_PASSWORD=testpassword
container_name: test-postgres-12
shm_size: 1gb
test-postgres-13:
image: postgres:13
ports:

View File

@@ -33,6 +33,7 @@ type EnvVariables struct {
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
TestPostgres12Port string `env:"TEST_POSTGRES_12_PORT"`
TestPostgres13Port string `env:"TEST_POSTGRES_13_PORT"`
TestPostgres14Port string `env:"TEST_POSTGRES_14_PORT"`
TestPostgres15Port string `env:"TEST_POSTGRES_15_PORT"`
@@ -44,6 +45,16 @@ type EnvVariables struct {
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
TestNASPort string `env:"TEST_NAS_PORT"`
// oauth
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
// testing Telegram
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
}
var (
@@ -135,6 +146,10 @@ func loadEnvVariables() {
env.TempFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "temp")
if env.IsTesting {
if env.TestPostgres12Port == "" {
log.Error("TEST_POSTGRES_12_PORT is empty")
os.Exit(1)
}
if env.TestPostgres13Port == "" {
log.Error("TEST_POSTGRES_13_PORT is empty")
os.Exit(1)
@@ -173,6 +188,16 @@ func loadEnvVariables() {
log.Error("TEST_NAS_PORT is empty")
os.Exit(1)
}
if env.TestTelegramBotToken == "" {
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
os.Exit(1)
}
if env.TestTelegramChatID == "" {
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
os.Exit(1)
}
}
log.Info("Environment variables loaded successfully!")

View File

@@ -1,37 +0,0 @@
package downdetect
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
)
type DowndetectController struct {
service *DowndetectService
}
func (c *DowndetectController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/downdetect/is-available", c.IsAvailable)
}
// @Summary Check API availability
// @Description Checks if the API service is available
// @Tags downdetect
// @Accept json
// @Produce json
// @Success 200
// @Failure 500
// @Router /downdetect/api [get]
func (c *DowndetectController) IsAvailable(ctx *gin.Context) {
err := c.service.IsDbAvailable()
if err != nil {
ctx.JSON(
http.StatusInternalServerError,
gin.H{"error": fmt.Sprintf("Database is not available: %v", err)},
)
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "API and DB are available"})
}

View File

@@ -1,10 +0,0 @@
package downdetect
var downdetectService = &DowndetectService{}
var downdetectController = &DowndetectController{
downdetectService,
}
func GetDowndetectController() *DowndetectController {
return downdetectController
}

View File

@@ -1,17 +0,0 @@
package downdetect
import (
"postgresus-backend/internal/storage"
)
type DowndetectService struct {
}
func (s *DowndetectService) IsDbAvailable() error {
err := storage.GetDb().Exec("SELECT 1").Error
if err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,111 @@
package audit_logs
import (
"net/http"
user_models "postgresus-backend/internal/features/users/models"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type AuditLogController struct {
auditLogService *AuditLogService
}
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
// All audit log endpoints require authentication (handled in main.go)
auditRoutes := router.Group("/audit-logs")
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
}
// GetGlobalAuditLogs
// @Summary Get global audit logs (ADMIN only)
// @Description Retrieve all audit logs across the system
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/global [get]
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
if err != nil {
if err.Error() == "only administrators can view global audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}
// GetUserAuditLogs
// @Summary Get user audit logs
// @Description Retrieve audit logs for a specific user
// @Tags audit-logs
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param userId path string true "User ID"
// @Param limit query int false "Limit number of results" default(100)
// @Param offset query int false "Offset for pagination" default(0)
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
// @Success 200 {object} GetAuditLogsResponse
// @Failure 400 {object} map[string]string
// @Failure 401 {object} map[string]string
// @Failure 403 {object} map[string]string
// @Router /audit-logs/users/{userId} [get]
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
user, isOk := ctx.MustGet("user").(*user_models.User)
if !isOk {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
return
}
userIDStr := ctx.Param("userId")
targetUserID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
request := &GetAuditLogsRequest{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
if err != nil {
if err.Error() == "insufficient permissions to view user audit logs" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
return
}
ctx.JSON(http.StatusOK, response)
}

View File

@@ -0,0 +1,154 @@
package audit_logs
import (
"fmt"
"net/http"
"testing"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
workspaceID := uuid.New()
testID := uuid.New().String()
// Create test logs with unique identifiers
userLogMessage := fmt.Sprintf("Test log with user %s", testID)
workspaceLogMessage := fmt.Sprintf("Test log with workspace %s", testID)
standaloneLogMessage := fmt.Sprintf("Test log standalone %s", testID)
createAuditLog(service, userLogMessage, &adminUser.UserID, nil)
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
createAuditLog(service, standaloneLogMessage, nil, nil)
// Test ADMIN can access global logs
var response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
"/api/v1/audit-logs/global?limit=100", "Bearer "+adminUser.Token, http.StatusOK, &response)
// Verify our specific test logs are present
messages := extractMessages(response.AuditLogs)
assert.Contains(t, messages, userLogMessage)
assert.Contains(t, messages, workspaceLogMessage)
assert.Contains(t, messages, standaloneLogMessage)
// Test MEMBER cannot access global logs
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
"Bearer "+memberUser.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
}
func Test_GetUserAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
router := createRouter()
service := GetAuditLogService()
workspaceID := uuid.New()
testID := uuid.New().String()
// Create test logs for different users with unique identifiers
user1FirstMessage := fmt.Sprintf("Test log user1 first %s", testID)
user1SecondMessage := fmt.Sprintf("Test log user1 second %s", testID)
user2FirstMessage := fmt.Sprintf("Test log user2 first %s", testID)
user2SecondMessage := fmt.Sprintf("Test log user2 second %s", testID)
workspaceLogMessage := fmt.Sprintf("Test workspace log %s", testID)
createAuditLog(service, user1FirstMessage, &user1.UserID, nil)
createAuditLog(service, user1SecondMessage, &user1.UserID, &workspaceID)
createAuditLog(service, user2FirstMessage, &user2.UserID, nil)
createAuditLog(service, user2SecondMessage, &user2.UserID, &workspaceID)
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
// Test ADMIN can view any user's logs
var user1Response GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user1.UserID.String()),
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
// Verify user1's specific logs are present
messages := extractMessages(user1Response.AuditLogs)
assert.Contains(t, messages, user1FirstMessage)
assert.Contains(t, messages, user1SecondMessage)
// Count only our test logs for user1
testLogsCount := 0
for _, message := range messages {
if message == user1FirstMessage || message == user1SecondMessage {
testLogsCount++
}
}
assert.Equal(t, 2, testLogsCount)
// Test user can view own logs
var ownLogsResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user2.UserID.String()),
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
// Verify user2's specific logs are present
ownMessages := extractMessages(ownLogsResponse.AuditLogs)
assert.Contains(t, ownMessages, user2FirstMessage)
assert.Contains(t, ownMessages, user2SecondMessage)
// Test user cannot view other user's logs
resp := test_utils.MakeGetRequest(t, router,
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
"Bearer "+user2.Token, http.StatusForbidden)
assert.Contains(t, string(resp.Body), "insufficient permissions")
}
func Test_GetGlobalAuditLogs_WithBeforeDateFilter_ReturnsFilteredLogs(t *testing.T) {
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
router := createRouter()
baseTime := time.Now().UTC()
// Set filter time to 30 minutes ago
beforeTime := baseTime.Add(-30 * time.Minute)
var filteredResponse GetAuditLogsResponse
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/audit-logs/global?beforeDate=%s&limit=1000",
beforeTime.Format(time.RFC3339),
),
"Bearer "+adminUser.Token,
http.StatusOK,
&filteredResponse,
)
// Verify ALL returned logs are older than the filter time
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime),
fmt.Sprintf("Log created at %s should be before filter time %s",
log.CreatedAt.Format(time.RFC3339), beforeTime.Format(time.RFC3339)))
}
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
SetupDependencies()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
return router
}

View File

@@ -0,0 +1,29 @@
package audit_logs
import (
users_services "postgresus-backend/internal/features/users/services"
"postgresus-backend/internal/util/logger"
)
var auditLogRepository = &AuditLogRepository{}
var auditLogService = &AuditLogService{
auditLogRepository: auditLogRepository,
logger: logger.GetLogger(),
}
var auditLogController = &AuditLogController{
auditLogService: auditLogService,
}
func GetAuditLogService() *AuditLogService {
return auditLogService
}
func GetAuditLogController() *AuditLogController {
return auditLogController
}
func SetupDependencies() {
users_services.GetUserService().SetAuditLogWriter(auditLogService)
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
}

View File

@@ -0,0 +1,31 @@
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type GetAuditLogsRequest struct {
Limit int `form:"limit" json:"limit"`
Offset int `form:"offset" json:"offset"`
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
}
type GetAuditLogsResponse struct {
AuditLogs []*AuditLogDTO `json:"auditLogs"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}
type AuditLogDTO struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
UserEmail *string `json:"userEmail" gorm:"column:user_email"`
UserName *string `json:"userName" gorm:"column:user_name"`
WorkspaceName *string `json:"workspaceName" gorm:"column:workspace_name"`
}

View File

@@ -0,0 +1,19 @@
package audit_logs
import (
"time"
"github.com/google/uuid"
)
type AuditLog struct {
ID uuid.UUID `json:"id" gorm:"column:id"`
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
Message string `json:"message" gorm:"column:message"`
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
}
func (AuditLog) TableName() string {
return "audit_logs"
}

View File

@@ -0,0 +1,139 @@
package audit_logs
import (
"postgresus-backend/internal/storage"
"time"
"github.com/google/uuid"
)
type AuditLogRepository struct{}
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
if auditLog.ID == uuid.Nil {
auditLog.ID = uuid.New()
}
return storage.GetDb().Create(auditLog).Error
}
func (r *AuditLogRepository) GetGlobal(
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id`
args := []interface{}{}
if beforeDate != nil {
sql += " WHERE al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByUser(
userID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.user_id = ?`
args := []interface{}{userID}
if beforeDate != nil {
sql += " AND al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) GetByWorkspace(
workspaceID uuid.UUID,
limit, offset int,
beforeDate *time.Time,
) ([]*AuditLogDTO, error) {
var auditLogs = make([]*AuditLogDTO, 0)
sql := `
SELECT
al.id,
al.user_id,
al.workspace_id,
al.message,
al.created_at,
u.email as user_email,
u.name as user_name,
w.name as workspace_name
FROM audit_logs al
LEFT JOIN users u ON al.user_id = u.id
LEFT JOIN workspaces w ON al.workspace_id = w.id
WHERE al.workspace_id = ?`
args := []interface{}{workspaceID}
if beforeDate != nil {
sql += " AND al.created_at < ?"
args = append(args, *beforeDate)
}
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
args = append(args, limit, offset)
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
return auditLogs, err
}
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
var count int64
query := storage.GetDb().Model(&AuditLog{})
if beforeDate != nil {
query = query.Where("created_at < ?", *beforeDate)
}
err := query.Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,137 @@
package audit_logs
import (
"errors"
"log/slog"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
user_models "postgresus-backend/internal/features/users/models"
"github.com/google/uuid"
)
type AuditLogService struct {
auditLogRepository *AuditLogRepository
logger *slog.Logger
}
func (s *AuditLogService) WriteAuditLog(
message string,
userID *uuid.UUID,
workspaceID *uuid.UUID,
) {
auditLog := &AuditLog{
UserID: userID,
WorkspaceID: workspaceID,
Message: message,
CreatedAt: time.Now().UTC(),
}
err := s.auditLogRepository.Create(auditLog)
if err != nil {
s.logger.Error("failed to create audit log", "error", err)
return
}
}
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
return s.auditLogRepository.Create(auditLog)
}
func (s *AuditLogService) GetGlobalAuditLogs(
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
if user.Role != user_enums.UserRoleAdmin {
return nil, errors.New("only administrators can view global audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
if err != nil {
return nil, err
}
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetUserAuditLogs(
targetUserID uuid.UUID,
user *user_models.User,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
// Users can view their own logs, ADMIN can view any user's logs
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
return nil, errors.New("insufficient permissions to view user audit logs")
}
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByUser(
targetUserID,
limit,
offset,
request.BeforeDate,
)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}
func (s *AuditLogService) GetWorkspaceAuditLogs(
workspaceID uuid.UUID,
request *GetAuditLogsRequest,
) (*GetAuditLogsResponse, error) {
limit := request.Limit
if limit <= 0 || limit > 1000 {
limit = 100
}
offset := max(request.Offset, 0)
auditLogs, err := s.auditLogRepository.GetByWorkspace(
workspaceID,
limit,
offset,
request.BeforeDate,
)
if err != nil {
return nil, err
}
return &GetAuditLogsResponse{
AuditLogs: auditLogs,
Total: int64(len(auditLogs)),
Limit: limit,
Offset: offset,
}, nil
}

View File

@@ -0,0 +1,83 @@
package audit_logs
import (
"testing"
"time"
user_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
service := GetAuditLogService()
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
workspace1ID, workspace2ID := uuid.New(), uuid.New()
// Create test logs for workspaces
createAuditLog(service, "Test workspace1 log first", &user1.UserID, &workspace1ID)
createAuditLog(service, "Test workspace1 log second", &user2.UserID, &workspace1ID)
createAuditLog(service, "Test workspace2 log first", &user1.UserID, &workspace2ID)
createAuditLog(service, "Test workspace2 log second", &user2.UserID, &workspace2ID)
createAuditLog(service, "Test no workspace log", &user1.UserID, nil)
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
// Test workspace 1 logs
workspace1Response, err := service.GetWorkspaceAuditLogs(workspace1ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(workspace1Response.AuditLogs))
messages := extractMessages(workspace1Response.AuditLogs)
assert.Contains(t, messages, "Test workspace1 log first")
assert.Contains(t, messages, "Test workspace1 log second")
for _, log := range workspace1Response.AuditLogs {
assert.Equal(t, &workspace1ID, log.WorkspaceID)
}
// Test workspace 2 logs
workspace2Response, err := service.GetWorkspaceAuditLogs(workspace2ID, request)
assert.NoError(t, err)
assert.Equal(t, 2, len(workspace2Response.AuditLogs))
messages2 := extractMessages(workspace2Response.AuditLogs)
assert.Contains(t, messages2, "Test workspace2 log first")
assert.Contains(t, messages2, "Test workspace2 log second")
// Test pagination
limitedResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
&GetAuditLogsRequest{Limit: 1, Offset: 0})
assert.NoError(t, err)
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
assert.Equal(t, 1, limitedResponse.Limit)
// Test beforeDate filter
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
filteredResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
assert.NoError(t, err)
for _, log := range filteredResponse.AuditLogs {
assert.True(t, log.CreatedAt.Before(beforeTime))
assert.NotNil(t, log.UserEmail, "User email should be present for logs with user_id")
assert.NotNil(
t,
log.WorkspaceName,
"Workspace name should be present for logs with workspace_id",
)
}
}
func createAuditLog(service *AuditLogService, message string, userID, workspaceID *uuid.UUID) {
service.WriteAuditLog(message, userID, workspaceID)
}
func extractMessages(logs []*AuditLogDTO) []string {
messages := make([]string, len(logs))
for i, log := range logs {
messages[i] = log.Message
}
return messages
}

View File

@@ -6,7 +6,9 @@ import (
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/period"
"testing"
"time"
@@ -16,10 +18,12 @@ import (
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -40,11 +44,8 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
// add old backup
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -67,16 +68,20 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -97,11 +102,8 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
// add recent backup (1 hour ago)
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusCompleted,
@@ -124,16 +126,20 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries disabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -157,11 +163,8 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -185,16 +188,20 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -218,11 +225,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
// add failed backup
failMessage := "backup failed"
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -246,16 +250,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
// setup data
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
// Enable backups for the database with retries enabled
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
@@ -280,11 +288,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
for i := 0; i < 3; i++ {
backupRepository.Save(&Backup{
Database: database,
DatabaseID: database.ID,
Storage: storage,
StorageID: storage.ID,
StorageID: storage.ID,
Status: BackupStatusFailed,
FailMessage: &failMessage,
@@ -309,6 +314,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
}
databases.RemoveTestDatabase(database)
storages.RemoveTestStorage(storage.ID)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}

View File

@@ -0,0 +1,47 @@
package backups
import (
"context"
"errors"
"sync"
"github.com/google/uuid"
)
type BackupContextManager struct {
mu sync.RWMutex
cancelFuncs map[uuid.UUID]context.CancelFunc
}
func NewBackupContextManager() *BackupContextManager {
return &BackupContextManager{
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
}
}
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
m.cancelFuncs[backupID] = cancelFunc
}
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
m.mu.Lock()
defer m.mu.Unlock()
cancelFunc, exists := m.cancelFuncs[backupID]
if !exists {
return errors.New("backup is not in progress or already completed")
}
cancelFunc()
delete(m.cancelFuncs, backupID)
return nil
}
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.cancelFuncs, backupID)
}

View File

@@ -4,7 +4,7 @@ import (
"fmt"
"io"
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -12,7 +12,6 @@ import (
type BackupController struct {
backupService *BackupService
userService *users.UserService
}
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
@@ -20,51 +19,48 @@ func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/backups", c.MakeBackup)
router.GET("/backups/:id/file", c.GetFile)
router.DELETE("/backups/:id", c.DeleteBackup)
router.POST("/backups/:id/cancel", c.CancelBackup)
}
// GetBackups
// @Summary Get backups for a database
// @Description Get all backups for the specified database
// @Description Get paginated backups for the specified database
// @Tags backups
// @Produce json
// @Param database_id query string true "Database ID"
// @Success 200 {array} Backup
// @Param limit query int false "Number of items per page" default(10)
// @Param offset query int false "Offset for pagination" default(0)
// @Success 200 {object} GetBackupsResponse
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups [get]
func (c *BackupController) GetBackups(ctx *gin.Context) {
databaseIDStr := ctx.Query("database_id")
if databaseIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "database_id query parameter is required"})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
databaseID, err := uuid.Parse(databaseIDStr)
var request GetBackupsRequest
if err := ctx.ShouldBindQuery(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
databaseID, err := uuid.Parse(request.DatabaseID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
backups, err := c.backupService.GetBackups(user, databaseID)
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, backups)
ctx.JSON(http.StatusOK, response)
}
// MakeBackup
@@ -80,24 +76,18 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
// @Failure 500
// @Router /backups [post]
func (c *BackupController) MakeBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request MakeBackupRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -117,24 +107,18 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
// @Failure 500
// @Router /backups/{id} [delete]
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.backupService.DeleteBackup(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -143,6 +127,37 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
ctx.Status(http.StatusNoContent)
}
// CancelBackup
// @Summary Cancel an in-progress backup
// @Description Cancel a backup that is currently in progress
// @Tags backups
// @Param id path string true "Backup ID"
// @Success 204
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /backups/{id}/cancel [post]
func (c *BackupController) CancelBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
if err := c.backupService.CancelBackup(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.Status(http.StatusNoContent)
}
// GetFile
// @Summary Download a backup file
// @Description Download the backup file for the specified backup
@@ -154,24 +169,18 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
// @Failure 500
// @Router /backups/{id}/file [get]
func (c *BackupController) GetFile(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
fileReader, err := c.backupService.GetBackupFile(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -179,19 +188,16 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
}
defer func() {
if err := fileReader.Close(); err != nil {
// Log the error but don't interrupt the response
fmt.Printf("Error closing file reader: %v\n", err)
}
}()
// Set headers for file download
ctx.Header("Content-Type", "application/octet-stream")
ctx.Header(
"Content-Disposition",
fmt.Sprintf("attachment; filename=\"backup_%s.dump\"", id.String()),
)
// Stream the file content
_, err = io.Copy(ctx.Writer, fileReader)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})

View File

@@ -0,0 +1,708 @@
package backups
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
audit_logs "postgresus-backend/internal/features/audit_logs"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/storages"
local_storage "postgresus-backend/internal/features/storages/models/local"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func Test_GetBackups_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace viewer can get backups",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get backups",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get backups",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can get backups",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups?database_id=%s", database.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if tt.expectSuccess {
var response GetBackupsResponse
err := json.Unmarshal(testResp.Body, &response)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(response.Backups), 1)
assert.GreaterOrEqual(t, response.Total, int64(1))
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can create backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot create backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can create backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
request := MakeBackupRequest{DatabaseID: database.ID}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
)
if tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "backup started successfully")
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CreateBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
enableBackupForDatabase(database.ID)
request := MakeBackupRequest{DatabaseID: database.ID}
test_utils.MakePostRequest(
t,
router,
"/api/v1/backups",
"Bearer "+owner.Token,
request,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup manually initiated") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup creation not found")
}
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace member can delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace viewer cannot delete backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "non-member cannot delete backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can delete backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
} else {
userService := users_services.GetUserService()
ownerUser, err := userService.GetUserFromToken(owner.Token)
assert.NoError(t, err)
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
assert.NoError(t, err)
assert.Equal(t, 0, len(response.Backups))
}
})
}
}
func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusNoContent,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup deleted") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup deletion not found")
}
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace viewer can download backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can download backup",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot download backup",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can download backup",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil {
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup file downloaded") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for backup download not found")
}
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
assert.NoError(t, err)
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
assert.NoError(t, err)
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storage.ID,
Status: BackupStatusInProgress,
BackupSizeMb: 0,
BackupDurationMs: 0,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
err = repo.Save(backup)
assert.NoError(t, err)
// Register a cancellable context for the backup
GetBackupService().backupContextMgr.RegisterBackup(backup.ID, func() {})
resp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/backups/%s/cancel", backup.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusNoContent,
)
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
// Verify audit log was created
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
userService := users_services.GetUserService()
adminUser, err := userService.GetUserFromToken(admin.Token)
assert.NoError(t, err)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetGlobalAuditLogs(
adminUser,
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
)
assert.NoError(t, err)
foundCancelLog := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Backup cancelled") &&
strings.Contains(log.Message, database.Name) {
foundCancelLog = true
break
}
}
assert.True(t, foundCancelLog, "Cancel audit log should be created")
}
func createTestRouter() *gin.Engine {
return CreateTestRouter()
}
func createTestDatabase(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
Name: name,
WorkspaceID: &workspaceID,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,
Type: storages.StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
}
repo := &storages.StorageRepository{}
storage, err := repo.Save(storage)
if err != nil {
panic(err)
}
return storage
}
func enableBackupForDatabase(databaseID uuid.UUID) {
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(databaseID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
}
func createTestDatabaseWithBackups(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
backup := createTestBackup(database, owner)
return database, backup
}
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
// Create a dummy backup file for testing download functionality
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}

View File

@@ -1,17 +1,21 @@
package backups
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups/usecases"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
"time"
)
var backupRepository = &BackupRepository{}
var backupContextManager = NewBackupContextManager()
var backupService = &BackupService{
databases.GetDatabaseService(),
storages.GetStorageService(),
@@ -22,6 +26,9 @@ var backupService = &BackupService{
usecases.GetCreateBackupUsecase(),
logger.GetLogger(),
[]BackupRemoveListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
backupContextManager,
}
var backupBackgroundService = &BackupBackgroundService{
@@ -35,7 +42,6 @@ var backupBackgroundService = &BackupBackgroundService{
var backupController = &BackupController{
backupService,
users.GetUserService(),
}
func SetupDependencies() {

View File

@@ -0,0 +1,14 @@
package backups
type GetBackupsRequest struct {
DatabaseID string `form:"database_id" binding:"required"`
Limit int `form:"limit"`
Offset int `form:"offset"`
}
type GetBackupsResponse struct {
Backups []*Backup `json:"backups"`
Total int64 `json:"total"`
Limit int `json:"limit"`
Offset int `json:"offset"`
}

View File

@@ -6,4 +6,5 @@ const (
BackupStatusInProgress BackupStatus = "IN_PROGRESS"
BackupStatusCompleted BackupStatus = "COMPLETED"
BackupStatusFailed BackupStatus = "FAILED"
BackupStatusCanceled BackupStatus = "CANCELED"
)

View File

@@ -1,6 +1,8 @@
package backups
import (
"context"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
@@ -19,6 +21,7 @@ type NotificationSender interface {
type CreateBackupUsecase interface {
Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,

View File

@@ -1,8 +1,6 @@
package backups
import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/storages"
"time"
"github.com/google/uuid"
@@ -11,11 +9,8 @@ import (
type Backup struct {
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
Database *databases.Database `json:"database" gorm:"foreignKey:DatabaseID"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
Status BackupStatus `json:"status" gorm:"column:status;not null"`
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`

View File

@@ -13,18 +13,20 @@ import (
type BackupRepository struct{}
func (r *BackupRepository) Save(backup *Backup) error {
if backup.DatabaseID == uuid.Nil || backup.StorageID == uuid.Nil {
return errors.New("database ID and storage ID are required")
}
db := storage.GetDb()
isNew := backup.ID == uuid.Nil
if isNew {
backup.ID = uuid.New()
return db.Create(backup).
Omit("Database", "Storage").
Error
}
return db.Save(backup).
Omit("Database", "Storage").
Error
}
@@ -33,8 +35,6 @@ func (r *BackupRepository) FindByDatabaseID(databaseID uuid.UUID) ([]*Backup, er
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -56,8 +56,6 @@ func (r *BackupRepository) FindByDatabaseIDWithLimit(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).
@@ -73,8 +71,6 @@ func (r *BackupRepository) FindByStorageID(storageID uuid.UUID) ([]*Backup, erro
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ?", storageID).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -89,8 +85,6 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ?", databaseID).
Order("created_at DESC").
First(&backup).Error; err != nil {
@@ -109,8 +103,6 @@ func (r *BackupRepository) FindByID(id uuid.UUID) (*Backup, error) {
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("id = ?", id).
First(&backup).Error; err != nil {
return nil, err
@@ -124,8 +116,6 @@ func (r *BackupRepository) FindByStatus(status BackupStatus) ([]*Backup, error)
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("status = ?", status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -143,8 +133,6 @@ func (r *BackupRepository) FindByStorageIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("storage_id = ? AND status = ?", storageID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -162,8 +150,6 @@ func (r *BackupRepository) FindByDatabaseIdAndStatus(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND status = ?", databaseID, status).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -185,8 +171,6 @@ func (r *BackupRepository) FindBackupsBeforeDate(
if err := storage.
GetDb().
Preload("Database").
Preload("Storage").
Where("database_id = ? AND created_at < ?", databaseID, date).
Order("created_at DESC").
Find(&backups).Error; err != nil {
@@ -195,3 +179,36 @@ func (r *BackupRepository) FindBackupsBeforeDate(
return backups, nil
}
func (r *BackupRepository) FindByDatabaseIDWithPagination(
databaseID uuid.UUID,
limit, offset int,
) ([]*Backup, error) {
var backups []*Backup
if err := storage.
GetDb().
Where("database_id = ?", databaseID).
Order("created_at DESC").
Limit(limit).
Offset(offset).
Find(&backups).Error; err != nil {
return nil, err
}
return backups, nil
}
func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
var count int64
if err := storage.
GetDb().
Model(&Backup{}).
Where("database_id = ?", databaseID).
Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}

View File

@@ -1,16 +1,20 @@
package backups
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"slices"
"strings"
"time"
"github.com/google/uuid"
@@ -29,6 +33,10 @@ type BackupService struct {
logger *slog.Logger
backupRemoveListeners []BackupRemoveListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
backupContextMgr *BackupContextManager
}
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
@@ -62,34 +70,74 @@ func (s *BackupService) MakeBackupWithAuth(
return err
}
if database.UserID != user.ID {
return errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot create backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to create backup for this database")
}
go s.MakeBackup(databaseID, true)
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}
func (s *BackupService) GetBackups(
user *users_models.User,
databaseID uuid.UUID,
) ([]*Backup, error) {
limit, offset int,
) (*GetBackupsResponse, error) {
database, err := s.databaseService.GetDatabaseByID(databaseID)
if err != nil {
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot get backups for database without workspace")
}
backups, err := s.backupRepository.FindByDatabaseID(databaseID)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access backups for this database")
}
if limit <= 0 {
limit = 10
}
if offset < 0 {
offset = 0
}
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
if err != nil {
return nil, err
}
return backups, nil
total, err := s.backupRepository.CountByDatabaseID(databaseID)
if err != nil {
return nil, err
}
return &GetBackupsResponse{
Backups: backups,
Total: total,
Limit: limit,
Offset: offset,
}, nil
}
func (s *BackupService) DeleteBackup(
@@ -101,14 +149,37 @@ func (s *BackupService) DeleteBackup(
return err
}
if backup.Database.UserID != user.ID {
return errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot delete backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to delete backup for this database")
}
if backup.Status == BackupStatusInProgress {
return errors.New("backup is in progress")
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup deleted for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return s.deleteBackup(backup)
}
@@ -154,10 +225,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
backup := &Backup{
DatabaseID: databaseID,
Database: database,
StorageID: storage.ID,
Storage: storage,
StorageID: storage.ID,
Status: BackupStatusInProgress,
@@ -184,7 +252,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
}
}
ctx, cancel := context.WithCancel(context.Background())
s.backupContextMgr.RegisterBackup(backup.ID, cancel)
defer s.backupContextMgr.UnregisterBackup(backup.ID)
err = s.createBackupUseCase.Execute(
ctx,
backup.ID,
backupConfig,
database,
@@ -193,6 +266,34 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
)
if err != nil {
errMsg := err.Error()
// Check if backup was cancelled (not due to shutdown)
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
backup.Status = BackupStatusCanceled
backup.BackupDurationMs = time.Since(start).Milliseconds()
backup.BackupSizeMb = 0
if err := s.backupRepository.Save(backup); err != nil {
s.logger.Error("Failed to save cancelled backup", "error", err)
}
// Delete partial backup from storage
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
if storageErr == nil {
if deleteErr := storage.DeleteFile(backup.ID); deleteErr != nil {
s.logger.Error(
"Failed to delete partial backup file",
"backupId",
backup.ID,
"error",
deleteErr,
)
}
}
return
}
backup.FailMessage = &errMsg
backup.Status = BackupStatusFailed
backup.BackupDurationMs = time.Since(start).Milliseconds()
@@ -319,6 +420,53 @@ func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
return s.backupRepository.FindByID(backupID)
}
func (s *BackupService) CancelBackup(
user *users_models.User,
backupID uuid.UUID,
) error {
backup, err := s.backupRepository.FindByID(backupID)
if err != nil {
return err
}
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot cancel backup for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to cancel backup for this database")
}
if backup.Status != BackupStatusInProgress {
return errors.New("backup is not in progress")
}
if err := s.backupContextMgr.CancelBackup(backupID); err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup cancelled for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return nil
}
func (s *BackupService) GetBackupFile(
user *users_models.User,
backupID uuid.UUID,
@@ -328,8 +476,24 @@ func (s *BackupService) GetBackupFile(
return nil, err
}
if backup.Database.UserID != user.ID {
return nil, errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot download backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to download backup for this database")
}
storage, err := s.storageService.GetStorageByID(backup.StorageID)
@@ -337,6 +501,16 @@ func (s *BackupService) GetBackupFile(
return nil, err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Backup file downloaded for database: %s (ID: %s)",
database.Name,
backupID.String(),
),
&user.ID,
database.WorkspaceID,
)
return storage.GetFile(backup.ID)
}
@@ -354,7 +528,10 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
err = storage.DeleteFile(backup.ID)
if err != nil {
return err
// we do not return error here, because sometimes clean up performed
// before unavailable storage removal or change - therefore we should
// proceed even in case of error
s.logger.Error("Failed to delete backup file", "error", err)
}
return s.backupRepository.DeleteByID(backup.ID)

View File

@@ -1,15 +1,19 @@
package backups
import (
"context"
"errors"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/logger"
"strings"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
@@ -17,15 +21,27 @@ import (
)
func Test_BackupExecuted_NotificationSent(t *testing.T) {
user := users.GetTestUser()
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
router := CreateTestRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
defer storages.RemoveTestStorage(storage.ID)
defer notifiers.RemoveTestNotifier(notifier)
defer databases.RemoveTestDatabase(database)
defer func() {
// cleanup backups first
backups, _ := backupRepository.FindByDatabaseID(database.ID)
for _, backup := range backups {
backupRepository.DeleteByID(backup.ID)
}
databases.RemoveTestDatabase(database)
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
notifiers.RemoveTestNotifier(notifier)
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}()
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
mockNotificationSender := &MockNotificationSender{}
@@ -39,6 +55,9 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
&CreateFailedBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
NewBackupContextManager(),
}
// Set up expectations
@@ -82,6 +101,9 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
NewBackupContextManager(),
}
backupService.MakeBackup(database.ID, true)
@@ -102,6 +124,9 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
&CreateSuccessBackupUsecase{},
logger.GetLogger(),
[]BackupRemoveListener{},
nil, // workspaceService
nil, // auditLogService
NewBackupContextManager(),
}
// capture arguments
@@ -137,6 +162,7 @@ type CreateFailedBackupUsecase struct {
}
func (uc *CreateFailedBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -153,6 +179,7 @@ type CreateSuccessBackupUsecase struct {
}
func (uc *CreateSuccessBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,

View File

@@ -0,0 +1,20 @@
package backups
import (
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"github.com/gin-gonic/gin"
)
func CreateTestRouter() *gin.Engine {
return workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
GetBackupController(),
)
}

View File

@@ -1,6 +1,7 @@
package usecases
import (
"context"
"errors"
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
backups_config "postgresus-backend/internal/features/backups/config"
@@ -16,6 +17,7 @@ type CreateBackupUsecase struct {
// Execute creates a backup of the database and returns the backup size in MB
func (uc *CreateBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
database *databases.Database,
@@ -26,6 +28,7 @@ func (uc *CreateBackupUsecase) Execute(
) error {
if database.Type == databases.DatabaseTypePostgres {
return uc.CreatePostgresqlBackupUsecase.Execute(
ctx,
backupID,
backupConfig,
database,

View File

@@ -29,6 +29,7 @@ type CreatePostgresqlBackupUsecase struct {
// Execute creates a backup of the database
func (uc *CreatePostgresqlBackupUsecase) Execute(
ctx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
db *databases.Database,
@@ -69,10 +70,10 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
"--verbose", // Add verbose output to help with debugging
}
// Use zstd compression level 5 for PostgreSQL 15+ (better compression and speed)
// Fall back to gzip compression level 5 for older versions
if pg.Version == tools.PostgresqlVersion13 || pg.Version == tools.PostgresqlVersion14 ||
pg.Version == tools.PostgresqlVersion15 {
// Use zstd compression level 5 for PostgreSQL 16+ (better compression and speed)
// Fall back to gzip compression level 5 for older versions (12-15)
if pg.Version == tools.PostgresqlVersion12 || pg.Version == tools.PostgresqlVersion13 ||
pg.Version == tools.PostgresqlVersion14 || pg.Version == tools.PostgresqlVersion15 {
args = append(args, "-Z", "5")
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
} else {
@@ -81,6 +82,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
}
return uc.streamToStorage(
ctx,
backupID,
backupConfig,
tools.GetPostgresqlExecutable(
@@ -99,6 +101,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
// streamToStorage streams pg_dump output directly to storage
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
parentCtx context.Context,
backupID uuid.UUID,
backupConfig *backups_config.BackupConfig,
pgBin string,
@@ -112,7 +115,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
// if backup not fit into 23 hours, Postgresus
// seems not to work for such database size
ctx, cancel := context.WithTimeout(context.Background(), 23*time.Hour)
ctx, cancel := context.WithTimeout(parentCtx, 23*time.Hour)
defer cancel()
// Monitor for shutdown and cancel context if needed
@@ -272,8 +275,9 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
bytesWritten := <-bytesWrittenCh
waitErr := cmd.Wait()
// Check for shutdown before finalizing
if config.IsShouldShutdown() {
// Check for shutdown or cancellation before finalizing
select {
case <-ctx.Done():
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
if err := pipeWriter.Close(); err != nil {
uc.logger.Error("Failed to close counting writer", "error", err)
@@ -281,7 +285,12 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
}
<-saveErrCh // Wait for storage to finish
return fmt.Errorf("backup cancelled due to shutdown")
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
}
// Close the pipe writer to signal end of data
@@ -303,8 +312,13 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
switch {
case waitErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
}
// Enhanced error handling for PostgreSQL connection and SSL issues
@@ -402,14 +416,24 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
return errors.New(errorMsg)
case copyErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
}
return fmt.Errorf("copy to storage: %w", copyErr)
case saveErr != nil:
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
select {
case <-ctx.Done():
if config.IsShouldShutdown() {
return fmt.Errorf("backup cancelled due to shutdown")
}
return fmt.Errorf("backup cancelled")
default:
}
return fmt.Errorf("save to storage: %w", saveErr)

View File

@@ -2,7 +2,7 @@ package backups_config
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type BackupConfigController struct {
backupConfigService *BackupConfigService
userService *users.UserService
}
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
@@ -32,24 +31,18 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 500
// @Router /backup-configs/save [post]
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var requestDTO BackupConfig
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
// make sure we rely on full .Storage object
requestDTO.StorageID = nil
@@ -74,30 +67,18 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
// @Failure 404
// @Router /backup-configs/database/{id} [get]
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
_, err = c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
backupConfig, err := c.backupConfigService.GetBackupConfigByDbIdWithAuth(user, id)
if err != nil {
ctx.JSON(http.StatusNotFound, gin.H{"error": "backup configuration not found"})
@@ -119,24 +100,18 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
// @Failure 500
// @Router /backup-configs/storage/{id}/is-using [get]
func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
isUsing, err := c.backupConfigService.IsStorageUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View File

@@ -0,0 +1,413 @@
package backups_config
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/storages"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"postgresus-backend/internal/util/period"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetBackupConfigController(),
)
return router
}
func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer cannot save backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can save backup config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
}
var response BackupConfig
testResp := test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, 2, response.CpuCount)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
timeOfDay := "04:00"
request := BackupConfig{
DatabaseID: database.ID,
IsBackupsEnabled: true,
StorePeriod: period.PeriodWeek,
BackupInterval: &intervals.Interval{
Interval: intervals.IntervalDaily,
TimeOfDay: &timeOfDay,
},
SendNotificationsOn: []BackupNotificationType{
NotificationBackupFailed,
},
CpuCount: 2,
IsRetryIfFailed: true,
MaxFailedTriesCount: 3,
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/backup-configs/save",
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get backup config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get backup config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get backup config",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusNotFound,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
var response BackupConfig
testResp := test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, database.ID, response.DatabaseID)
assert.NotNil(t, response.BackupInterval)
} else {
assert.Contains(t, string(testResp.Body), "backup configuration not found")
}
})
}
}
func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response BackupConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/database/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.False(t, response.IsBackupsEnabled)
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
assert.Equal(t, 1, response.CpuCount)
assert.True(t, response.IsRetryIfFailed)
assert.Equal(t, 3, response.MaxFailedTriesCount)
assert.NotNil(t, response.BackupInterval)
}
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isStorageOwner bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "storage owner can check storage usage",
isStorageOwner: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-storage-owner cannot check storage usage",
isStorageOwner: false,
expectSuccess: false,
expectedStatusCode: http.StatusInternalServerError,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
storageOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace(
"Test Workspace",
storageOwner,
router,
)
storage := createTestStorage(workspace.ID)
var testUserToken string
if tt.isStorageOwner {
testUserToken = storageOwner.Token
} else {
otherUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = otherUser.Token
}
if tt.expectSuccess {
var response map[string]bool
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
isUsing, exists := response["isUsing"]
assert.True(t, exists)
assert.False(t, isUsing)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "error")
}
// Cleanup
storages.RemoveTestStorage(storage.ID)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
return storages.CreateTestStorage(workspaceID)
}

View File

@@ -3,7 +3,7 @@ package backups_config
import (
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
)
var backupConfigRepository = &BackupConfigRepository{}
@@ -11,11 +11,11 @@ var backupConfigService = &BackupConfigService{
backupConfigRepository,
databases.GetDatabaseService(),
storages.GetStorageService(),
workspaces_services.GetWorkspaceService(),
nil,
}
var backupConfigController = &BackupConfigController{
backupConfigService,
users.GetUserService(),
}
func GetBackupConfigController() *BackupConfigController {

View File

@@ -1,10 +1,13 @@
package backups_config
import (
"errors"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/intervals"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/period"
"github.com/google/uuid"
@@ -14,6 +17,7 @@ type BackupConfigService struct {
backupConfigRepository *BackupConfigRepository
databaseService *databases.DatabaseService
storageService *storages.StorageService
workspaceService *workspaces_services.WorkspaceService
dbStorageChangeListener BackupConfigStorageChangeListener
}
@@ -32,11 +36,23 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
return nil, err
}
_, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
database, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot save backup config for database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to modify backup configuration")
}
return s.SaveBackupConfig(backupConfig)
}
@@ -66,19 +82,6 @@ func (s *BackupConfigService) SaveBackupConfig(
}
}
if !backupConfig.IsBackupsEnabled && existingConfig.StorageID != nil {
if err := s.dbStorageChangeListener.OnBeforeBackupsStorageChange(
backupConfig.DatabaseID,
); err != nil {
return nil, err
}
// we clear storage for disabled backups to allow
// storage removal for unused storages
backupConfig.Storage = nil
backupConfig.StorageID = nil
}
return s.backupConfigRepository.Save(backupConfig)
}
@@ -144,6 +147,10 @@ func (s *BackupConfigService) OnDatabaseCopied(originalDatabaseID, newDatabaseID
}
}
func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) error {
return s.initializeDefaultConfig(databaseID)
}
func (s *BackupConfigService) initializeDefaultConfig(
databaseID uuid.UUID,
) error {

View File

@@ -2,15 +2,18 @@ package databases
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type DatabaseController struct {
databaseService *DatabaseService
userService *users.UserService
databaseService *DatabaseService
userService *users_services.UserService
workspaceService *workspaces_services.WorkspaceService
}
func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
@@ -28,36 +31,35 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
// CreateDatabase
// @Summary Create a new database
// @Description Create a new database configuration
// @Description Create a new database configuration in a workspace
// @Tags databases
// @Accept json
// @Produce json
// @Param request body Database true "Database creation data"
// @Param request body Database true "Database creation data with workspaceId"
// @Success 201 {object} Database
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /databases/create [post]
func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
if request.WorkspaceID == nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
database, err := c.databaseService.CreateDatabase(user, &request)
database, err := c.databaseService.CreateDatabase(user, *request.WorkspaceID, &request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -79,24 +81,18 @@ func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
// @Failure 500
// @Router /databases/update [post]
func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.UpdateDatabase(user, &request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -116,24 +112,18 @@ func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id} [delete]
func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.DeleteDatabase(user, id); err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -153,24 +143,18 @@ func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
// @Failure 401
// @Router /databases/{id} [get]
func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
database, err := c.databaseService.GetDatabase(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -181,30 +165,38 @@ func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
}
// GetDatabases
// @Summary Get databases
// @Description Get all databases for the authenticated user
// @Summary Get databases by workspace
// @Description Get all databases for a specific workspace
// @Tags databases
// @Produce json
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Database
// @Failure 400
// @Failure 401
// @Failure 500
// @Router /databases [get]
func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
databases, err := c.databaseService.GetDatabasesByUser(user)
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
databases, err := c.databaseService.GetDatabasesByWorkspace(user, workspaceID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -222,24 +214,18 @@ func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id}/test-connection [post]
func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.databaseService.TestDatabaseConnection(user, id); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -259,27 +245,18 @@ func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
// @Failure 401
// @Router /databases/test-connection-direct [post]
func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
_, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Database
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
// Set user ID for validation purposes
request.UserID = user.ID
if err := c.databaseService.TestDatabaseConnectionDirect(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
@@ -300,24 +277,18 @@ func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
// @Failure 500
// @Router /databases/notifier/{id}/is-using [get]
func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
isUsing, err := c.databaseService.IsNotifierUsing(user, id)
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -339,24 +310,18 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
// @Failure 500
// @Router /databases/{id}/copy [post]
func (c *DatabaseController) CopyDatabase(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
id, err := uuid.Parse(ctx.Param("id"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
copiedDatabase, err := c.databaseService.CopyDatabase(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})

View File

@@ -0,0 +1,928 @@
package databases
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases/databases/postgresql"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
GetDatabaseController(),
)
return router
}
func Test_CreateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can create database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
{
name: "workspace member can create database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
{
name: "workspace viewer cannot create database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can create database",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
var response Database
testResp := test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, "Test Database", response.Name)
assert.NotEqual(t, uuid.Nil, response.ID)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CreateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testDbName := "test_db"
request := Database{
Name: "Test Database",
WorkspaceID: &workspace.ID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/create",
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_UpdateDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can update database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can update database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer cannot update database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can update database",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
database.Name = "Updated Database"
var response Database
testResp := test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/update",
"Bearer "+testUserToken,
database,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, "Updated Database", response.Name)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_UpdateDatabase_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
database.Name = "Hacked Name"
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/databases/update",
"Bearer "+nonMember.Token,
database,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_DeleteDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can delete database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace member can delete database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
{
name: "workspace viewer cannot delete database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusInternalServerError,
},
{
name: "global admin can delete database",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusNoContent,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
testResp := test_utils.MakeDeleteRequest(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
if !tt.expectSuccess {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_GetDatabase_PermissionsEnforced(t *testing.T) {
memberRole := users_enums.WorkspaceRoleViewer
tests := []struct {
name string
userRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace member can get database",
userRole: &memberRole,
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get database",
userRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can get database",
userRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.userRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.userRole, owner.Token, router)
testUser = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
var response Database
testResp := test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String(),
"Bearer "+testUser,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.Equal(t, database.ID, response.ID)
assert.Equal(t, "Test Database", response.Name)
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_GetDatabasesByWorkspace_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace member can list databases",
isMember: true,
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot list databases",
isMember: false,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can list databases",
isMember: false,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
if tt.expectSuccess {
var response []Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases?workspace_id="+workspace.ID.String(),
"Bearer "+testUser,
tt.expectedStatusCode,
&response,
)
assert.GreaterOrEqual(t, len(response), 2)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/databases?workspace_id="+workspace.ID.String(),
"Bearer "+testUser,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_GetDatabasesByWorkspace_WhenMultipleDatabasesExist_ReturnsCorrectCount(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
createTestDatabaseViaAPI("Database 1", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 2", workspace.ID, owner.Token, router)
createTestDatabaseViaAPI("Database 3", workspace.ID, owner.Token, router)
var response []Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases?workspace_id="+workspace.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, 3, len(response))
}
func Test_GetDatabasesByWorkspace_EnsuresCrossWorkspaceIsolation(t *testing.T) {
router := createTestRouter()
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
createTestDatabaseViaAPI("Workspace1 DB1", workspace1.ID, owner1.Token, router)
createTestDatabaseViaAPI("Workspace1 DB2", workspace1.ID, owner1.Token, router)
createTestDatabaseViaAPI("Workspace2 DB1", workspace2.ID, owner2.Token, router)
var workspace1Dbs []Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases?workspace_id="+workspace1.ID.String(),
"Bearer "+owner1.Token,
http.StatusOK,
&workspace1Dbs,
)
var workspace2Dbs []Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/databases?workspace_id="+workspace2.ID.String(),
"Bearer "+owner2.Token,
http.StatusOK,
&workspace2Dbs,
)
assert.Equal(t, 2, len(workspace1Dbs))
assert.Equal(t, 1, len(workspace2Dbs))
for _, db := range workspace1Dbs {
assert.Equal(t, workspace1.ID, *db.WorkspaceID)
}
for _, db := range workspace2Dbs {
assert.Equal(t, workspace2.ID, *db.WorkspaceID)
}
}
func Test_CopyDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can copy database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
{
name: "workspace member can copy database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
{
name: "workspace viewer cannot copy database",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can copy database",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusCreated,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
var response Database
testResp := test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/copy",
"Bearer "+testUserToken,
nil,
tt.expectedStatusCode,
&response,
)
if tt.expectSuccess {
assert.NotEqual(t, database.ID, response.ID)
assert.Contains(t, response.Name, "(Copy)")
} else {
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_CopyDatabase_CopyStaysInSameWorkspace(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/"+database.ID.String()+"/copy",
"Bearer "+owner.Token,
nil,
http.StatusCreated,
&response,
)
assert.NotEqual(t, database.ID, response.ID)
assert.Equal(t, "Test Database (Copy)", response.Name)
assert.Equal(t, workspace.ID, *response.WorkspaceID)
assert.Equal(t, database.Type, response.Type)
}
func Test_TestConnection_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
isMember bool
isGlobalAdmin bool
expectAccessGranted bool
expectedStatusCodeOnErr int
}{
{
name: "workspace member can test connection",
isMember: true,
isGlobalAdmin: false,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "non-member cannot test connection",
isMember: false,
isGlobalAdmin: false,
expectAccessGranted: false,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
{
name: "global admin can test connection",
isMember: false,
isGlobalAdmin: true,
expectAccessGranted: true,
expectedStatusCodeOnErr: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUser string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUser = admin.Token
} else if tt.isMember {
testUser = owner.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUser = nonMember.Token
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/"+database.ID.String()+"/test-connection",
"Bearer "+testUser,
nil,
)
body := w.Body.String()
if tt.expectAccessGranted {
assert.True(
t,
w.Code == http.StatusOK ||
(w.Code == http.StatusBadRequest && strings.Contains(body, "connect")),
"Expected 200 OK or 400 with connection error, got %d: %s",
w.Code,
body,
)
} else {
assert.Equal(t, tt.expectedStatusCodeOnErr, w.Code)
assert.Contains(t, body, "insufficient permissions")
}
})
}
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *Database {
testDbName := "test_db"
request := Database{
Name: name,
WorkspaceID: &workspaceID,
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
databaseType DatabaseType
createDatabase func(workspaceID uuid.UUID) *Database
updateDatabase func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database
verifySensitiveData func(t *testing.T, database *Database)
verifyHiddenData func(t *testing.T, database *Database)
}{
{
name: "PostgreSQL Database",
databaseType: DatabaseTypePostgres,
createDatabase: func(workspaceID uuid.UUID) *Database {
testDbName := "test_db"
return &Database{
WorkspaceID: &workspaceID,
Name: "Test PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "original-password-secret",
Database: &testDbName,
},
}
},
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
testDbName := "updated_test_db"
return &Database{
ID: databaseID,
WorkspaceID: &workspaceID,
Name: "Updated PostgreSQL Database",
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion17,
Host: "updated-host",
Port: 5433,
Username: "updated_user",
Password: "",
Database: &testDbName,
},
}
},
verifySensitiveData: func(t *testing.T, database *Database) {
assert.Equal(t, "original-password-secret", database.Postgresql.Password)
},
verifyHiddenData: func(t *testing.T, database *Database) {
assert.Equal(t, "", database.Postgresql.Password)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create database with sensitive data
initialDatabase := tc.createDatabase(workspace.ID)
var createdDatabase Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/create",
"Bearer "+owner.Token,
*initialDatabase,
http.StatusCreated,
&createdDatabase,
)
assert.NotEmpty(t, createdDatabase.ID)
assert.Equal(t, initialDatabase.Name, createdDatabase.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedDatabase Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/databases/%s", createdDatabase.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedDatabase,
)
tc.verifyHiddenData(t, &retrievedDatabase)
assert.Equal(t, initialDatabase.Name, retrievedDatabase.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedDatabase := tc.updateDatabase(workspace.ID, createdDatabase.ID)
var updateResponse Database
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/databases/update",
"Bearer "+owner.Token,
*updatedDatabase,
http.StatusOK,
&updateResponse,
)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &DatabaseRepository{}
databaseFromDB, err := repository.FindByID(createdDatabase.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, databaseFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedDatabase.Name, databaseFromDB.Name)
// Phase 5: Additional verification - Check via GET that data is still hidden
var finalRetrieved Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/databases/%s", createdDatabase.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
// Phase 6: Verify GetDatabasesByWorkspace also hides sensitive data
var workspaceDatabases []Database
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/databases?workspace_id=%s", workspace.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&workspaceDatabases,
)
var foundDatabase *Database
for i := range workspaceDatabases {
if workspaceDatabases[i].ID == createdDatabase.ID {
foundDatabase = &workspaceDatabases[i]
break
}
}
assert.NotNil(t, foundDatabase, "Database should be found in workspace databases list")
tc.verifyHiddenData(t, foundDatabase)
// Clean up: Delete database before removing workspace
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/databases/%s", createdDatabase.ID.String()),
"Bearer "+owner.Token,
http.StatusNoContent,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}

View File

@@ -66,6 +66,27 @@ func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
return testSingleDatabaseConnection(logger, ctx, p)
}
func (p *PostgresqlDatabase) HideSensitiveData() {
if p == nil {
return
}
p.Password = ""
}
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
p.Version = incoming.Version
p.Host = incoming.Host
p.Port = incoming.Port
p.Username = incoming.Username
p.Database = incoming.Database
p.IsHttps = incoming.IsHttps
if incoming.Password != "" {
p.Password = incoming.Password
}
}
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
func testSingleDatabaseConnection(
logger *slog.Logger,

View File

@@ -1,8 +1,10 @@
package databases
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/users"
users_services "postgresus-backend/internal/features/users/services"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -15,11 +17,14 @@ var databaseService = &DatabaseService{
[]DatabaseCreationListener{},
[]DatabaseRemoveListener{},
[]DatabaseCopyListener{},
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
}
var databaseController = &DatabaseController{
databaseService,
users.GetUserService(),
users_services.GetUserService(),
workspaces_services.GetWorkspaceService(),
}
func GetDatabaseService() *DatabaseService {
@@ -29,3 +34,7 @@ func GetDatabaseService() *DatabaseService {
func GetDatabaseController() *DatabaseController {
return databaseController
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
}

View File

@@ -12,6 +12,8 @@ type DatabaseValidator interface {
type DatabaseConnector interface {
TestConnection(logger *slog.Logger) error
HideSensitiveData()
}
type DatabaseCreationListener interface {

View File

@@ -11,10 +11,13 @@ import (
)
type Database struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;type:uuid;not null"`
Name string `json:"name" gorm:"column:name;type:text;not null"`
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
// WorkspaceID can be null when a database is created via restore operation
// outside the context of any workspace
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;type:uuid"`
Name string `json:"name" gorm:"column:name;type:text;not null"`
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
@@ -57,6 +60,23 @@ func (d *Database) TestConnection(logger *slog.Logger) error {
return d.getSpecificDatabase().TestConnection(logger)
}
func (d *Database) HideSensitiveData() {
d.getSpecificDatabase().HideSensitiveData()
}
func (d *Database) Update(incoming *Database) {
d.Name = incoming.Name
d.Type = incoming.Type
d.Notifiers = incoming.Notifiers
switch d.Type {
case DatabaseTypePostgres:
if d.Postgresql != nil && incoming.Postgresql != nil {
d.Postgresql.Update(incoming.Postgresql)
}
}
}
func (d *Database) getSpecificDatabase() DatabaseConnector {
switch d.Type {
case DatabaseTypePostgres:

View File

@@ -92,14 +92,14 @@ func (r *DatabaseRepository) FindByID(id uuid.UUID) (*Database, error) {
return &database, nil
}
func (r *DatabaseRepository) FindByUserID(userID uuid.UUID) ([]*Database, error) {
func (r *DatabaseRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Database, error) {
var databases []*Database
if err := storage.
GetDb().
Preload("Postgresql").
Preload("Notifiers").
Where("user_id = ?", userID).
Where("workspace_id = ?", workspaceID).
Order("CASE WHEN health_status = 'UNAVAILABLE' THEN 1 WHEN health_status = 'AVAILABLE' THEN 2 WHEN health_status IS NULL THEN 3 ELSE 4 END, name ASC").
Find(&databases).Error; err != nil {
return nil, err

View File

@@ -2,11 +2,15 @@ package databases
import (
"errors"
"fmt"
"log/slog"
"time"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/notifiers"
users_models "postgresus-backend/internal/features/users/models"
"time"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/google/uuid"
)
@@ -19,6 +23,9 @@ type DatabaseService struct {
dbCreationListener []DatabaseCreationListener
dbRemoveListener []DatabaseRemoveListener
dbCopyListener []DatabaseCopyListener
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
}
func (s *DatabaseService) AddDbCreationListener(
@@ -41,15 +48,24 @@ func (s *DatabaseService) AddDbCopyListener(
func (s *DatabaseService) CreateDatabase(
user *users_models.User,
workspaceID uuid.UUID,
database *Database,
) (*Database, error) {
database.UserID = user.ID
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to create database in this workspace")
}
database.WorkspaceID = &workspaceID
if err := database.Validate(); err != nil {
return nil, err
}
database, err := s.dbRepository.Save(database)
database, err = s.dbRepository.Save(database)
if err != nil {
return nil, err
}
@@ -58,6 +74,12 @@ func (s *DatabaseService) CreateDatabase(
listener.OnDatabaseCreated(database.ID)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database created: %s", database.Name),
&user.ID,
&workspaceID,
)
return database, nil
}
@@ -74,24 +96,39 @@ func (s *DatabaseService) UpdateDatabase(
return err
}
if existingDatabase.UserID != user.ID {
return errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return errors.New("cannot update database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to update this database")
}
// Validate the update
if err := database.ValidateUpdate(*existingDatabase, *database); err != nil {
return err
}
if err := database.Validate(); err != nil {
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return err
}
_, err = s.dbRepository.Save(database)
_, err = s.dbRepository.Save(existingDatabase)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return nil
}
@@ -104,8 +141,16 @@ func (s *DatabaseService) DeleteDatabase(
return err
}
if existingDatabase.UserID != user.ID {
return errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return errors.New("cannot delete database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to delete this database")
}
for _, listener := range s.dbRemoveListener {
@@ -114,6 +159,12 @@ func (s *DatabaseService) DeleteDatabase(
}
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database deleted: %s", existingDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return s.dbRepository.Delete(id)
}
@@ -126,17 +177,44 @@ func (s *DatabaseService) GetDatabase(
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("you have not access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot access database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access this database")
}
database.HideSensitiveData()
return database, nil
}
func (s *DatabaseService) GetDatabasesByUser(
func (s *DatabaseService) GetDatabasesByWorkspace(
user *users_models.User,
workspaceID uuid.UUID,
) ([]*Database, error) {
return s.dbRepository.FindByUserID(user.ID)
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access this workspace")
}
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return nil, err
}
for _, database := range databases {
database.HideSensitiveData()
}
return databases, nil
}
func (s *DatabaseService) IsNotifierUsing(
@@ -160,8 +238,16 @@ func (s *DatabaseService) TestDatabaseConnection(
return err
}
if database.UserID != user.ID {
return errors.New("you have not access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot test connection for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to test connection for this database")
}
err = database.TestConnection(s.logger)
@@ -184,7 +270,31 @@ func (s *DatabaseService) TestDatabaseConnection(
func (s *DatabaseService) TestDatabaseConnectionDirect(
database *Database,
) error {
return database.TestConnection(s.logger)
var usingDatabase *Database
if database.ID != uuid.Nil {
existingDatabase, err := s.dbRepository.FindByID(database.ID)
if err != nil {
return err
}
if database.WorkspaceID != nil && existingDatabase.WorkspaceID != nil &&
*existingDatabase.WorkspaceID != *database.WorkspaceID {
return errors.New("database does not belong to this workspace")
}
existingDatabase.Update(database)
if err := existingDatabase.Validate(); err != nil {
return err
}
usingDatabase = existingDatabase
} else {
usingDatabase = database
}
return usingDatabase.TestConnection(s.logger)
}
func (s *DatabaseService) GetDatabaseByID(
@@ -237,13 +347,21 @@ func (s *DatabaseService) CopyDatabase(
return nil, err
}
if existingDatabase.UserID != user.ID {
return nil, errors.New("you have not access to this database")
if existingDatabase.WorkspaceID == nil {
return nil, errors.New("cannot copy database without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canManage {
return nil, errors.New("insufficient permissions to copy this database")
}
newDatabase := &Database{
ID: uuid.Nil,
UserID: user.ID,
WorkspaceID: existingDatabase.WorkspaceID,
Name: existingDatabase.Name + " (Copy)",
Type: existingDatabase.Type,
Notifiers: existingDatabase.Notifiers,
@@ -286,6 +404,12 @@ func (s *DatabaseService) CopyDatabase(
listener.OnDatabaseCopied(databaseID, copiedDatabase.ID)
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Database copied: %s to %s", existingDatabase.Name, copiedDatabase.Name),
&user.ID,
existingDatabase.WorkspaceID,
)
return copiedDatabase, nil
}
@@ -306,3 +430,19 @@ func (s *DatabaseService) SetHealthStatus(
return nil
}
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return err
}
if len(databases) > 0 {
return fmt.Errorf(
"workspace contains %d databases that must be deleted",
len(databases),
)
}
return nil
}

View File

@@ -10,14 +10,14 @@ import (
)
func CreateTestDatabase(
userID uuid.UUID,
workspaceID uuid.UUID,
storage *storages.Storage,
notifier *notifiers.Notifier,
) *Database {
database := &Database{
UserID: userID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
WorkspaceID: &workspaceID,
Name: "test " + uuid.New().String(),
Type: DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,

View File

@@ -10,23 +10,34 @@ import (
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func Test_CheckPgHealthUseCase(t *testing.T) {
user := users.GetTestUser()
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
storage := storages.CreateTestStorage(user.UserID)
notifier := notifiers.CreateTestNotifier(user.UserID)
// Create workspace directly via service
workspace, err := workspaces_testing.CreateTestWorkspaceDirect("Test Workspace", user.UserID)
if err != nil {
t.Fatalf("Failed to create workspace: %v", err)
}
defer storages.RemoveTestStorage(storage.ID)
defer notifiers.RemoveTestNotifier(notifier)
storage := storages.CreateTestStorage(workspace.ID)
notifier := notifiers.CreateTestNotifier(workspace.ID)
defer func() {
storages.RemoveTestStorage(storage.ID)
notifiers.RemoveTestNotifier(notifier)
workspaces_testing.RemoveTestWorkspaceDirect(workspace.ID)
}()
t.Run("Test_DbAttemptFailed_DbMarkedAsUnavailable", func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender
@@ -94,7 +105,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbNotMarkerdAsDownAfterFirstAttempt",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender
@@ -160,7 +171,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbMarkerdAsDownAfterThirdFailedAttempt",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Make sure DB is available
@@ -237,7 +248,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
)
t.Run("Test_UnavailableDbAttemptSucceed_DbMarkedAsAvailable", func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Make sure DB is unavailable
@@ -311,7 +322,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
t.Run(
"Test_DbHealthcheckExecutedFast_HealthcheckNotExecutedFasterThanInterval",
func(t *testing.T) {
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
defer databases.RemoveTestDatabase(database)
// Setup mock notifier sender

View File

@@ -2,7 +2,7 @@ package healthcheck_attempt
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"time"
"github.com/gin-gonic/gin"
@@ -11,7 +11,6 @@ import (
type HealthcheckAttemptController struct {
healthcheckAttemptService *HealthcheckAttemptService
userService *users.UserService
}
func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
@@ -31,9 +30,9 @@ func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /healthcheck-attempts/{databaseId} [get]
func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -43,7 +42,7 @@ func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
return
}
afterDate := time.Now().UTC()
afterDate := time.Now().UTC().Add(-7 * 24 * time.Hour)
if afterDateStr := ctx.Query("afterDate"); afterDateStr != "" {
parsedDate, err := time.Parse(time.RFC3339, afterDateStr)
if err != nil {

View File

@@ -0,0 +1,261 @@
package healthcheck_attempt
import (
"encoding/json"
"fmt"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetHealthcheckAttemptController(),
)
return router
}
func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get healthcheck attempts",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get healthcheck attempts",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get healthcheck attempts",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
pastTime := time.Now().UTC().Add(-1 * time.Hour)
createTestHealthcheckAttemptWithTime(
database.ID,
databases.HealthStatusAvailable,
pastTime,
)
createTestHealthcheckAttemptWithTime(
database.ID,
databases.HealthStatusUnavailable,
pastTime.Add(-30*time.Minute),
)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
if tt.expectSuccess {
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
assert.GreaterOrEqual(t, len(response), 2)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "forbidden")
}
})
}
}
func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
oldTime := time.Now().UTC().Add(-2 * time.Hour)
recentTime := time.Now().UTC().Add(-30 * time.Minute)
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusAvailable, oldTime)
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusUnavailable, recentTime)
createTestHealthcheckAttempt(database.ID, databases.HealthStatusAvailable)
afterDate := time.Now().UTC().Add(-1 * time.Hour)
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf(
"/api/v1/healthcheck-attempts/%s?afterDate=%s",
database.ID.String(),
afterDate.Format(time.RFC3339),
),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, 2, len(response))
for _, attempt := range response {
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
}
}
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response []*HealthcheckAttempt
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-attempts/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, 0, len(response))
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestHealthcheckAttempt(databaseID uuid.UUID, status databases.HealthStatus) {
createTestHealthcheckAttemptWithTime(databaseID, status, time.Now().UTC())
}
func createTestHealthcheckAttemptWithTime(
databaseID uuid.UUID,
status databases.HealthStatus,
createdAt time.Time,
) {
repo := GetHealthcheckAttemptRepository()
attempt := &HealthcheckAttempt{
ID: uuid.New(),
DatabaseID: databaseID,
Status: status,
CreatedAt: createdAt,
}
if err := repo.Create(attempt); err != nil {
panic("Failed to create test healthcheck attempt: " + err.Error())
}
}

View File

@@ -4,7 +4,7 @@ import (
"postgresus-backend/internal/features/databases"
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
"postgresus-backend/internal/features/notifiers"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -12,6 +12,7 @@ var healthcheckAttemptRepository = &HealthcheckAttemptRepository{}
var healthcheckAttemptService = &HealthcheckAttemptService{
healthcheckAttemptRepository,
databases.GetDatabaseService(),
workspaces_services.GetWorkspaceService(),
}
var checkPgHealthUseCase = &CheckPgHealthUseCase{
@@ -27,7 +28,10 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
}
var healthcheckAttemptController = &HealthcheckAttemptController{
healthcheckAttemptService,
users.GetUserService(),
}
func GetHealthcheckAttemptRepository() *HealthcheckAttemptRepository {
return healthcheckAttemptRepository
}
func GetHealthcheckAttemptService() *HealthcheckAttemptService {

View File

@@ -53,7 +53,7 @@ func (r *HealthcheckAttemptRepository) DeleteOlderThan(
Delete(&HealthcheckAttempt{}).Error
}
func (r *HealthcheckAttemptRepository) Insert(
func (r *HealthcheckAttemptRepository) Create(
attempt *HealthcheckAttempt,
) error {
if attempt.ID == uuid.Nil {
@@ -67,6 +67,12 @@ func (r *HealthcheckAttemptRepository) Insert(
return storage.GetDb().Create(attempt).Error
}
func (r *HealthcheckAttemptRepository) Insert(
attempt *HealthcheckAttempt,
) error {
return r.Create(attempt)
}
func (r *HealthcheckAttemptRepository) FindByDatabaseIDWithLimit(
databaseID uuid.UUID,
limit int,

View File

@@ -4,6 +4,7 @@ import (
"errors"
"postgresus-backend/internal/features/databases"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"time"
"github.com/google/uuid"
@@ -12,6 +13,7 @@ import (
type HealthcheckAttemptService struct {
healthcheckAttemptRepository *HealthcheckAttemptRepository
databaseService *databases.DatabaseService
workspaceService *workspaces_services.WorkspaceService
}
func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
@@ -24,7 +26,15 @@ func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
return nil, err
}
if database.UserID != user.ID {
if database.WorkspaceID == nil {
return nil, errors.New("cannot access healthcheck attempts for databases without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("forbidden")
}

View File

@@ -2,7 +2,7 @@ package healthcheck_config
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type HealthcheckConfigController struct {
healthcheckConfigService *HealthcheckConfigService
userService *users.UserService
}
func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
@@ -31,9 +30,9 @@ func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /healthcheck-config [post]
func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -65,9 +64,9 @@ func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
// @Failure 401
// @Router /healthcheck-config/{databaseId} [get]
func (c *HealthcheckConfigController) GetHealthcheckConfig(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}

View File

@@ -0,0 +1,328 @@
package healthcheck_config
import (
"encoding/json"
"net/http"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
users_enums "postgresus-backend/internal/features/users/enums"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
GetHealthcheckConfigController(),
)
return router
}
func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer cannot save healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
{
name: "global admin can save healthcheck config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
}
request := HealthcheckConfigDTO{
DatabaseID: database.ID,
IsHealthcheckEnabled: true,
IsSentNotificationWhenUnavailable: true,
IntervalMinutes: 5,
AttemptsBeforeConcideredAsDown: 3,
StoreAttemptsDays: 7,
}
if tt.expectSuccess {
var response map[string]string
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
&response,
)
assert.Contains(t, response["message"], "successfully")
} else {
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+testUserToken,
request,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := HealthcheckConfigDTO{
DatabaseID: database.ID,
IsHealthcheckEnabled: true,
IsSentNotificationWhenUnavailable: true,
IntervalMinutes: 5,
AttemptsBeforeConcideredAsDown: 3,
StoreAttemptsDays: 7,
}
testResp := test_utils.MakePostRequest(
t,
router,
"/api/v1/healthcheck-config",
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
tests := []struct {
name string
workspaceRole *users_enums.WorkspaceRole
isGlobalAdmin bool
expectSuccess bool
expectedStatusCode int
}{
{
name: "workspace owner can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace admin can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace member can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "workspace viewer can get healthcheck config",
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
isGlobalAdmin: false,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "global admin can get healthcheck config",
workspaceRole: nil,
isGlobalAdmin: true,
expectSuccess: true,
expectedStatusCode: http.StatusOK,
},
{
name: "non-member cannot get healthcheck config",
workspaceRole: nil,
isGlobalAdmin: false,
expectSuccess: false,
expectedStatusCode: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var testUserToken string
if tt.isGlobalAdmin {
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
testUserToken = admin.Token
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
testUserToken = owner.Token
} else if tt.workspaceRole != nil {
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
testUserToken = member.Token
} else {
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testUserToken = nonMember.Token
}
if tt.expectSuccess {
var response HealthcheckConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsHealthcheckEnabled)
} else {
testResp := test_utils.MakeGetRequest(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+testUserToken,
tt.expectedStatusCode,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
})
}
}
func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
var response HealthcheckConfig
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/healthcheck-config/"+database.ID.String(),
"Bearer "+owner.Token,
http.StatusOK,
&response,
)
assert.Equal(t, database.ID, response.DatabaseID)
assert.True(t, response.IsHealthcheckEnabled)
assert.True(t, response.IsSentNotificationWhenUnavailable)
assert.Equal(t, 1, response.IntervalMinutes)
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
assert.Equal(t, 7, response.StoreAttemptsDays)
}
func createTestDatabaseViaAPI(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic("Failed to create database")
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}

View File

@@ -1,8 +1,9 @@
package healthcheck_config
import (
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -10,11 +11,12 @@ var healthcheckConfigRepository = &HealthcheckConfigRepository{}
var healthcheckConfigService = &HealthcheckConfigService{
databases.GetDatabaseService(),
healthcheckConfigRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
logger.GetLogger(),
}
var healthcheckConfigController = &HealthcheckConfigController{
healthcheckConfigService,
users.GetUserService(),
}
func GetHealthcheckConfigService() *HealthcheckConfigService {

View File

@@ -2,9 +2,12 @@ package healthcheck_config
import (
"errors"
"fmt"
"log/slog"
"postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/databases"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/google/uuid"
)
@@ -12,6 +15,8 @@ import (
type HealthcheckConfigService struct {
databaseService *databases.DatabaseService
healthcheckConfigRepository *HealthcheckConfigRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
logger *slog.Logger
}
@@ -33,8 +38,16 @@ func (s *HealthcheckConfigService) Save(
return err
}
if database.UserID != user.ID {
return errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return errors.New("cannot modify healthcheck config for databases without workspace")
}
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, &user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to modify healthcheck config")
}
healthcheckConfig := configDTO.ToDTO()
@@ -60,6 +73,12 @@ func (s *HealthcheckConfigService) Save(
}
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Healthcheck config updated for database '%s'", database.Name),
&user.ID,
database.WorkspaceID,
)
return nil
}
@@ -72,8 +91,16 @@ func (s *HealthcheckConfigService) GetByDatabaseID(
return nil, err
}
if database.UserID != user.ID {
return nil, errors.New("user does not have access to this database")
if database.WorkspaceID == nil {
return nil, errors.New("cannot access healthcheck config for databases without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to view healthcheck config")
}
config, err := s.healthcheckConfigRepository.GetByDatabaseID(database.ID)

View File

@@ -2,15 +2,16 @@ package notifiers
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type NotifierController struct {
notifierService *NotifierService
userService *users.UserService
notifierService *NotifierService
workspaceService *workspaces_services.WorkspaceService
}
func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,35 +30,40 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param notifier body Notifier true "Notifier data"
// @Param request body Notifier true "Notifier data with workspaceId"
// @Success 200 {object} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers [post]
func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var notifier Notifier
if err := ctx.ShouldBindJSON(&notifier); err != nil {
var request Notifier
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := notifier.Validate(); err != nil {
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.notifierService.SaveNotifier(user, &notifier); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, notifier)
ctx.JSON(http.StatusOK, request)
}
// GetNotifier
@@ -70,11 +76,12 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
// @Success 200 {object} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id} [get]
func (c *NotifierController) GetNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -86,6 +93,10 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
notifier, err := c.notifierService.GetNotifier(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -95,22 +106,41 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
// GetNotifiers
// @Summary Get all notifiers
// @Description Get all notifiers for the current user
// @Description Get all notifiers for a workspace
// @Tags notifiers
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Notifier
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers [get]
func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
notifiers, err := c.notifierService.GetNotifiers(user)
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -128,11 +158,12 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id} [delete]
func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -142,13 +173,11 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
return
}
notifier, err := c.notifierService.GetNotifier(user, id)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.notifierService.DeleteNotifier(user, notifier.ID); err != nil {
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -166,11 +195,12 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/{id}/test [post]
func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -181,6 +211,10 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
}
if err := c.notifierService.SendTestNotification(user, id); err != nil {
if err.Error() == "insufficient permissions to test notifier in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -195,28 +229,44 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param notifier body Notifier true "Notifier data"
// @Param request body Notifier true "Notifier data with workspaceId"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /notifiers/direct-test [post]
func (c *NotifierController) SendTestNotificationDirect(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var notifier Notifier
if err := ctx.ShouldBindJSON(&notifier); err != nil {
var request Notifier
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// For direct test, associate with the current user
notifier.UserID = user.ID
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.notifierService.SendTestNotificationToNotifier(&notifier); err != nil {
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !canView {
ctx.JSON(
http.StatusForbidden,
gin.H{"error": "insufficient permissions to test notifier in this workspace"},
)
return
}
if err := c.notifierService.SendTestNotificationToNotifier(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

View File

@@ -0,0 +1,815 @@
package notifiers
import (
"fmt"
"net/http"
"testing"
"postgresus-backend/internal/config"
audit_logs "postgresus-backend/internal/features/audit_logs"
discord_notifier "postgresus-backend/internal/features/notifiers/models/discord"
email_notifier "postgresus-backend/internal/features/notifiers/models/email_notifier"
slack_notifier "postgresus-backend/internal/features/notifiers/models/slack"
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_SaveNewNotifier_NotifierReturnedViaGet(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
verifyNotifierData(t, notifier, &savedNotifier)
assert.NotEmpty(t, savedNotifier.ID)
// Verify notifier is returned via GET
var retrievedNotifier Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedNotifier,
)
verifyNotifierData(t, &savedNotifier, &retrievedNotifier)
// Verify notifier is returned via GET all notifiers
var notifiers []Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&notifiers,
)
assert.Len(t, notifiers, 1)
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UpdateExistingNotifier_UpdatedNotifierReturnedViaGet(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
updatedName := "Updated Notifier " + uuid.New().String()
savedNotifier.Name = updatedName
var updatedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
savedNotifier,
http.StatusOK,
&updatedNotifier,
)
assert.Equal(t, updatedName, updatedNotifier.Name)
assert.Equal(t, savedNotifier.ID, updatedNotifier.ID)
deleteNotifier(t, router, updatedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteNotifier_NotifierNotReturnedViaGet(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
response := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+owner.Token,
http.StatusBadRequest,
)
assert.Contains(t, string(response.Body), "error")
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_SendTestNotificationDirect_NotificationSent(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createTelegramNotifier(workspace.ID)
response := test_utils.MakePostRequest(
t, router, "/api/v1/notifiers/direct-test", "Bearer "+owner.Token, *notifier, http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_SendTestNotificationExisting_NotificationSent(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createTelegramNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
response := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s/test", savedNotifier.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_ViewerCanViewNotifiers_ButCannotModify(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
viewer,
users_enums.WorkspaceRoleViewer,
owner.Token,
router,
)
notifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
// Viewer can GET notifiers
var notifiers []Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
"Bearer "+viewer.Token,
http.StatusOK,
&notifiers,
)
assert.Len(t, notifiers, 1)
// Viewer cannot CREATE notifier
newNotifier := createNewNotifier(workspace.ID)
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, *newNotifier, http.StatusForbidden,
)
// Viewer cannot UPDATE notifier
savedNotifier.Name = "Updated by viewer"
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+viewer.Token, savedNotifier, http.StatusForbidden,
)
// Viewer cannot DELETE notifier
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+viewer.Token,
http.StatusForbidden,
)
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_MemberCanManageNotifiers(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
notifier := createNewNotifier(workspace.ID)
// Member can CREATE notifier
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+member.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
assert.NotEmpty(t, savedNotifier.ID)
// Member can UPDATE notifier
savedNotifier.Name = "Updated by member"
var updatedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+member.Token,
savedNotifier,
http.StatusOK,
&updatedNotifier,
)
assert.Equal(t, "Updated by member", updatedNotifier.Name)
// Member can DELETE notifier
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_AdminCanManageNotifiers(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
admin,
users_enums.WorkspaceRoleAdmin,
owner.Token,
router,
)
notifier := createNewNotifier(workspace.ID)
// Admin can CREATE, UPDATE, DELETE
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+admin.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
savedNotifier.Name = "Updated by admin"
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+admin.Token, savedNotifier, http.StatusOK,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UserNotInWorkspace_CannotAccessNotifiers(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
notifier := createNewNotifier(workspace.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*notifier,
http.StatusOK,
&savedNotifier,
)
// Outsider cannot GET notifiers
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers?workspace_id=%s", workspace.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
// Outsider cannot CREATE notifier
test_utils.MakePostRequest(
t, router, "/api/v1/notifiers", "Bearer "+outsider.Token, *notifier, http.StatusForbidden,
)
// Outsider cannot UPDATE notifier
test_utils.MakePostRequest(
t,
router,
"/api/v1/notifiers",
"Bearer "+outsider.Token,
savedNotifier,
http.StatusForbidden,
)
// Outsider cannot DELETE notifier
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
deleteNotifier(t, router, savedNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CrossWorkspaceSecurity_CannotAccessNotifierFromAnotherWorkspace(t *testing.T) {
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
notifier1 := createNewNotifier(workspace1.ID)
var savedNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner1.Token,
*notifier1,
http.StatusOK,
&savedNotifier,
)
// Try to access workspace1's notifier with owner2 from workspace2
response := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", savedNotifier.ID.String()),
"Bearer "+owner2.Token,
http.StatusForbidden,
)
assert.Contains(t, string(response.Body), "insufficient permissions")
deleteNotifier(t, router, savedNotifier.ID, workspace1.ID, owner1.Token)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetNotifierController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewNotifier(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Notifier " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.site/test-" + uuid.New().String(),
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
}
func createTelegramNotifier(workspaceID uuid.UUID) *Notifier {
env := config.GetEnv()
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier " + uuid.New().String(),
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: env.TestTelegramBotToken,
TargetChatID: env.TestTelegramChatID,
},
}
}
func verifyNotifierData(t *testing.T, expected *Notifier, actual *Notifier) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.NotifierType, actual.NotifierType)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteNotifier(
t *testing.T,
router *gin.Engine,
notifierID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", notifierID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func Test_NotifierSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
notifierType NotifierType
createNotifier func(workspaceID uuid.UUID) *Notifier
updateNotifier func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier
verifySensitiveData func(t *testing.T, notifier *Notifier)
verifyHiddenData func(t *testing.T, notifier *Notifier)
}{
{
name: "Telegram Notifier",
notifierType: NotifierTypeTelegram,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Telegram Notifier",
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: "original-bot-token-12345",
TargetChatID: "123456789",
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Telegram Notifier",
NotifierType: NotifierTypeTelegram,
TelegramNotifier: &telegram_notifier.TelegramNotifier{
BotToken: "",
TargetChatID: "987654321",
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-bot-token-12345", notifier.TelegramNotifier.BotToken)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.TelegramNotifier.BotToken)
},
},
{
name: "Email Notifier",
notifierType: NotifierTypeEmail,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Email Notifier",
NotifierType: NotifierTypeEmail,
EmailNotifier: &email_notifier.EmailNotifier{
TargetEmail: "test@example.com",
SMTPHost: "smtp.example.com",
SMTPPort: 587,
SMTPUser: "user@example.com",
SMTPPassword: "original-password-secret",
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Email Notifier",
NotifierType: NotifierTypeEmail,
EmailNotifier: &email_notifier.EmailNotifier{
TargetEmail: "updated@example.com",
SMTPHost: "smtp.newhost.com",
SMTPPort: 465,
SMTPUser: "newuser@example.com",
SMTPPassword: "",
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "original-password-secret", notifier.EmailNotifier.SMTPPassword)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.EmailNotifier.SMTPPassword)
},
},
{
name: "Slack Notifier",
notifierType: NotifierTypeSlack,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Slack Notifier",
NotifierType: NotifierTypeSlack,
SlackNotifier: &slack_notifier.SlackNotifier{
BotToken: "xoxb-original-slack-token",
TargetChatID: "C123456",
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Slack Notifier",
NotifierType: NotifierTypeSlack,
SlackNotifier: &slack_notifier.SlackNotifier{
BotToken: "",
TargetChatID: "C789012",
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "xoxb-original-slack-token", notifier.SlackNotifier.BotToken)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.SlackNotifier.BotToken)
},
},
{
name: "Discord Notifier",
notifierType: NotifierTypeDiscord,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Discord Notifier",
NotifierType: NotifierTypeDiscord,
DiscordNotifier: &discord_notifier.DiscordNotifier{
ChannelWebhookURL: "https://discord.com/api/webhooks/123/original-token",
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Discord Notifier",
NotifierType: NotifierTypeDiscord,
DiscordNotifier: &discord_notifier.DiscordNotifier{
ChannelWebhookURL: "",
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(
t,
"https://discord.com/api/webhooks/123/original-token",
notifier.DiscordNotifier.ChannelWebhookURL,
)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.DiscordNotifier.ChannelWebhookURL)
},
},
{
name: "Teams Notifier",
notifierType: NotifierTypeTeams,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Teams Notifier",
NotifierType: NotifierTypeTeams,
TeamsNotifier: &teams_notifier.TeamsNotifier{
WebhookURL: "https://outlook.office.com/webhook/original-token",
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Teams Notifier",
NotifierType: NotifierTypeTeams,
TeamsNotifier: &teams_notifier.TeamsNotifier{
WebhookURL: "",
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
assert.Equal(
t,
"https://outlook.office.com/webhook/original-token",
notifier.TeamsNotifier.WebhookURL,
)
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
assert.Equal(t, "", notifier.TeamsNotifier.WebhookURL)
},
},
{
name: "Webhook Notifier",
notifierType: NotifierTypeWebhook,
createNotifier: func(workspaceID uuid.UUID) *Notifier {
return &Notifier{
WorkspaceID: workspaceID,
Name: "Test Webhook Notifier",
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/test",
WebhookMethod: webhook_notifier.WebhookMethodPOST,
},
}
},
updateNotifier: func(workspaceID uuid.UUID, notifierID uuid.UUID) *Notifier {
return &Notifier{
ID: notifierID,
WorkspaceID: workspaceID,
Name: "Updated Webhook Notifier",
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{
WebhookURL: "https://webhook.example.com/updated",
WebhookMethod: webhook_notifier.WebhookMethodGET,
},
}
},
verifySensitiveData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to verify for webhook
},
verifyHiddenData: func(t *testing.T, notifier *Notifier) {
// No sensitive data to hide for webhook
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create notifier with sensitive data
initialNotifier := tc.createNotifier(workspace.ID)
var createdNotifier Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*initialNotifier,
http.StatusOK,
&createdNotifier,
)
assert.NotEmpty(t, createdNotifier.ID)
assert.Equal(t, initialNotifier.Name, createdNotifier.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedNotifier Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", createdNotifier.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedNotifier,
)
tc.verifyHiddenData(t, &retrievedNotifier)
assert.Equal(t, initialNotifier.Name, retrievedNotifier.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedNotifier := tc.updateNotifier(workspace.ID, createdNotifier.ID)
var updateResponse Notifier
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/notifiers",
"Bearer "+owner.Token,
*updatedNotifier,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedNotifier.Name, updateResponse.Name)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &NotifierRepository{}
notifierFromDB, err := repository.FindByID(createdNotifier.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, notifierFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedNotifier.Name, notifierFromDB.Name)
// Phase 5: Additional verification - Check via GET that data is still hidden
var finalRetrieved Notifier
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/notifiers/%s", createdNotifier.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
deleteNotifier(t, router, createdNotifier.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
})
}
}

View File

@@ -1,7 +1,8 @@
package notifiers
import (
"postgresus-backend/internal/features/users"
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -9,10 +10,12 @@ var notifierRepository = &NotifierRepository{}
var notifierService = &NotifierService{
notifierRepository,
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
}
var notifierController = &NotifierController{
notifierService,
users.GetUserService(),
workspaces_services.GetWorkspaceService(),
}
func GetNotifierController() *NotifierController {
@@ -22,3 +25,10 @@ func GetNotifierController() *NotifierController {
func GetNotifierService() *NotifierService {
return notifierService
}
func GetNotifierRepository() *NotifierRepository {
return notifierRepository
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
}

View File

@@ -6,4 +6,6 @@ type NotificationSender interface {
Send(logger *slog.Logger, heading string, message string) error
Validate() error
HideSensitiveData()
}

View File

@@ -15,7 +15,7 @@ import (
type Notifier struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null;type:uuid;index"`
WorkspaceID uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;not null;type:uuid;index"`
Name string `json:"name" gorm:"column:name;not null;type:varchar(255)"`
NotifierType NotifierType `json:"notifierType" gorm:"column:notifier_type;not null;type:varchar(50)"`
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
@@ -54,6 +54,42 @@ func (n *Notifier) Send(logger *slog.Logger, heading string, message string) err
return err
}
func (n *Notifier) HideSensitiveData() {
n.getSpecificNotifier().HideSensitiveData()
}
func (n *Notifier) Update(incoming *Notifier) {
n.Name = incoming.Name
n.NotifierType = incoming.NotifierType
switch n.NotifierType {
case NotifierTypeTelegram:
if n.TelegramNotifier != nil && incoming.TelegramNotifier != nil {
n.TelegramNotifier.Update(incoming.TelegramNotifier)
}
case NotifierTypeEmail:
if n.EmailNotifier != nil && incoming.EmailNotifier != nil {
n.EmailNotifier.Update(incoming.EmailNotifier)
}
case NotifierTypeWebhook:
if n.WebhookNotifier != nil && incoming.WebhookNotifier != nil {
n.WebhookNotifier.Update(incoming.WebhookNotifier)
}
case NotifierTypeSlack:
if n.SlackNotifier != nil && incoming.SlackNotifier != nil {
n.SlackNotifier.Update(incoming.SlackNotifier)
}
case NotifierTypeDiscord:
if n.DiscordNotifier != nil && incoming.DiscordNotifier != nil {
n.DiscordNotifier.Update(incoming.DiscordNotifier)
}
case NotifierTypeTeams:
if n.TeamsNotifier != nil && incoming.TeamsNotifier != nil {
n.TeamsNotifier.Update(incoming.TeamsNotifier)
}
}
}
func (n *Notifier) getSpecificNotifier() NotificationSender {
switch n.NotifierType {
case NotifierTypeTelegram:

View File

@@ -71,3 +71,13 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
return nil
}
func (d *DiscordNotifier) HideSensitiveData() {
d.ChannelWebhookURL = ""
}
func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
if incoming.ChannelWebhookURL != "" {
d.ChannelWebhookURL = incoming.ChannelWebhookURL
}
}

View File

@@ -27,6 +27,7 @@ type EmailNotifier struct {
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
}
func (e *EmailNotifier) TableName() string {
@@ -56,9 +57,12 @@ func (e *EmailNotifier) Validate() error {
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
// Compose email
from := e.SMTPUser
from := e.From
if from == "" {
from = "noreply@" + e.SMTPHost
from = e.SMTPUser
if from == "" {
from = "noreply@" + e.SMTPHost
}
}
to := []string{e.TargetEmail}
@@ -208,3 +212,19 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
return client.Quit()
}
}
func (e *EmailNotifier) HideSensitiveData() {
e.SMTPPassword = ""
}
func (e *EmailNotifier) Update(incoming *EmailNotifier) {
e.TargetEmail = incoming.TargetEmail
e.SMTPHost = incoming.SMTPHost
e.SMTPPort = incoming.SMTPPort
e.SMTPUser = incoming.SMTPUser
e.From = incoming.From
if incoming.SMTPPassword != "" {
e.SMTPPassword = incoming.SMTPPassword
}
}

View File

@@ -132,3 +132,15 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
return nil
}
}
func (s *SlackNotifier) HideSensitiveData() {
s.BotToken = ""
}
func (s *SlackNotifier) Update(incoming *SlackNotifier) {
s.TargetChatID = incoming.TargetChatID
if incoming.BotToken != "" {
s.BotToken = incoming.BotToken
}
}

View File

@@ -94,3 +94,13 @@ func (n *TeamsNotifier) Send(logger *slog.Logger, heading, message string) error
return nil
}
func (n *TeamsNotifier) HideSensitiveData() {
n.WebhookURL = ""
}
func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
if incoming.WebhookURL != "" {
n.WebhookURL = incoming.WebhookURL
}
}

View File

@@ -80,3 +80,16 @@ func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message str
return nil
}
func (t *TelegramNotifier) HideSensitiveData() {
t.BotToken = ""
}
func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
t.TargetChatID = incoming.TargetChatID
t.ThreadID = incoming.ThreadID
if incoming.BotToken != "" {
t.BotToken = incoming.BotToken
}
}

View File

@@ -102,3 +102,11 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
}
}
func (t *WebhookNotifier) HideSensitiveData() {
}
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
t.WebhookURL = incoming.WebhookURL
t.WebhookMethod = incoming.WebhookMethod
}

View File

@@ -143,7 +143,7 @@ func (r *NotifierRepository) FindByID(id uuid.UUID) (*Notifier, error) {
return &notifier, nil
}
func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error) {
func (r *NotifierRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Notifier, error) {
var notifiers []*Notifier
if err := storage.
@@ -154,7 +154,7 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
Preload("SlackNotifier").
Preload("DiscordNotifier").
Preload("TeamsNotifier").
Where("user_id = ?", userID).
Where("workspace_id = ?", workspaceID).
Order("name ASC").
Find(&notifiers).Error; err != nil {
return nil, err
@@ -165,7 +165,6 @@ func (r *NotifierRepository) FindByUserID(userID uuid.UUID) ([]*Notifier, error)
func (r *NotifierRepository) Delete(notifier *Notifier) error {
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
switch notifier.NotifierType {
case NotifierTypeTelegram:
if notifier.TelegramNotifier != nil {

View File

@@ -2,8 +2,12 @@ package notifiers
import (
"errors"
"fmt"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/google/uuid"
)
@@ -11,30 +15,68 @@ import (
type NotifierService struct {
notifierRepository *NotifierRepository
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
}
func (s *NotifierService) SaveNotifier(
user *users_models.User,
workspaceID uuid.UUID,
notifier *Notifier,
) error {
if notifier.ID != uuid.Nil {
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
}
isUpdate := notifier.ID != uuid.Nil
if isUpdate {
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
if err != nil {
return err
}
if existingNotifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
if existingNotifier.WorkspaceID != workspaceID {
return errors.New("notifier does not belong to this workspace")
}
notifier.UserID = existingNotifier.UserID
} else {
notifier.UserID = user.ID
}
existingNotifier.Update(notifier)
_, err := s.notifierRepository.Save(notifier)
if err != nil {
return err
if err := existingNotifier.Validate(); err != nil {
return err
}
_, err = s.notifierRepository.Save(existingNotifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier updated: %s", existingNotifier.Name),
&user.ID,
&workspaceID,
)
} else {
notifier.WorkspaceID = workspaceID
if err := notifier.Validate(); err != nil {
return err
}
_, err = s.notifierRepository.Save(notifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier created: %s", notifier.Name),
&user.ID,
&workspaceID,
)
}
return nil
@@ -49,11 +91,26 @@ func (s *NotifierService) DeleteNotifier(
return err
}
if notifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
canManage, err := s.workspaceService.CanUserManageDBs(notifier.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage notifier in this workspace")
}
return s.notifierRepository.Delete(notifier)
err = s.notifierRepository.Delete(notifier)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Notifier deleted: %s", notifier.Name),
&user.ID,
&notifier.WorkspaceID,
)
return nil
}
func (s *NotifierService) GetNotifier(
@@ -65,17 +122,40 @@ func (s *NotifierService) GetNotifier(
return nil, err
}
if notifier.UserID != user.ID {
return nil, errors.New("you have not access to this notifier")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifier in this workspace")
}
notifier.HideSensitiveData()
return notifier, nil
}
func (s *NotifierService) GetNotifiers(
user *users_models.User,
workspaceID uuid.UUID,
) ([]*Notifier, error) {
return s.notifierRepository.FindByUserID(user.ID)
canView, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view notifiers in this workspace")
}
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return nil, err
}
for _, notifier := range notifiers {
notifier.HideSensitiveData()
}
return notifiers, nil
}
func (s *NotifierService) SendTestNotification(
@@ -87,8 +167,12 @@ func (s *NotifierService) SendTestNotification(
return err
}
if notifier.UserID != user.ID {
return errors.New("you have not access to this notifier")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(notifier.WorkspaceID, user)
if err != nil {
return err
}
if !canView {
return errors.New("insufficient permissions to test notifier in this workspace")
}
err = notifier.Send(s.logger, "Test message", "This is a test message")
@@ -107,7 +191,30 @@ func (s *NotifierService) SendTestNotification(
func (s *NotifierService) SendTestNotificationToNotifier(
notifier *Notifier,
) error {
return notifier.Send(s.logger, "Test message", "This is a test message")
var usingNotifier *Notifier
if notifier.ID != uuid.Nil {
existingNotifier, err := s.notifierRepository.FindByID(notifier.ID)
if err != nil {
return err
}
if existingNotifier.WorkspaceID != notifier.WorkspaceID {
return errors.New("notifier does not belong to this workspace")
}
existingNotifier.Update(notifier)
if err := existingNotifier.Validate(); err != nil {
return err
}
usingNotifier = existingNotifier
} else {
usingNotifier = notifier
}
return usingNotifier.Send(s.logger, "Test message", "This is a test message")
}
func (s *NotifierService) SendNotification(
@@ -143,3 +250,18 @@ func (s *NotifierService) SendNotification(
s.logger.Error("Failed to save notifier", "error", err)
}
}
func (s *NotifierService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
notifiers, err := s.notifierRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return fmt.Errorf("failed to get notifiers for workspace deletion: %w", err)
}
for _, notifier := range notifiers {
if err := s.notifierRepository.Delete(notifier); err != nil {
return fmt.Errorf("failed to delete notifier %s: %w", notifier.ID, err)
}
}
return nil
}

View File

@@ -6,9 +6,9 @@ import (
"github.com/google/uuid"
)
func CreateTestNotifier(userID uuid.UUID) *Notifier {
func CreateTestNotifier(workspaceID uuid.UUID) *Notifier {
notifier := &Notifier{
UserID: userID,
WorkspaceID: workspaceID,
Name: "test " + uuid.New().String(),
NotifierType: NotifierTypeWebhook,
WebhookNotifier: &webhook_notifier.WebhookNotifier{

View File

@@ -2,7 +2,7 @@ package restores
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -10,7 +10,6 @@ import (
type RestoreController struct {
restoreService *RestoreService
userService *users.UserService
}
func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,24 +28,18 @@ func (c *RestoreController) RegisterRoutes(router *gin.RouterGroup) {
// @Failure 401
// @Router /restores/{backupId} [get]
func (c *RestoreController) GetRestores(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
backupID, err := uuid.Parse(ctx.Param("backupId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
restores, err := c.restoreService.GetRestores(user, backupID)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
@@ -66,6 +59,12 @@ func (c *RestoreController) GetRestores(ctx *gin.Context) {
// @Failure 401
// @Router /restores/{backupId}/restore [post]
func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
backupID, err := uuid.Parse(ctx.Param("backupId"))
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
@@ -78,18 +77,6 @@ func (c *RestoreController) RestoreBackup(ctx *gin.Context) {
return
}
authorizationHeader := ctx.GetHeader("Authorization")
if authorizationHeader == "" {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
return
}
user, err := c.userService.GetUserFromToken(authorizationHeader)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
return
}
if err := c.restoreService.RestoreBackupWithAuth(user, backupID, requestDTO); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return

View File

@@ -0,0 +1,346 @@
package restores
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/databases/databases/postgresql"
"postgresus-backend/internal/features/restores/models"
"postgresus-backend/internal/features/storages"
local_storage "postgresus-backend/internal/features/storages/models/local"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_models "postgresus-backend/internal/features/workspaces/models"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"postgresus-backend/internal/util/tools"
)
func createTestRouter() *gin.Engine {
router := workspaces_testing.CreateTestRouter(
workspaces_controllers.GetWorkspaceController(),
workspaces_controllers.GetMembershipController(),
databases.GetDatabaseController(),
backups_config.GetBackupConfigController(),
backups.GetBackupController(),
GetRestoreController(),
)
return router
}
func Test_GetRestores_WhenUserIsWorkspaceMember_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
var restores []*models.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&restores,
)
assert.NotNil(t, restores)
assert.Equal(t, 0, len(restores))
assert.NotNil(t, database)
}
func Test_GetRestores_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
testResp := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+nonMember.Token,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_GetRestores_WhenUserIsGlobalAdmin_RestoresReturned(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
var restores []*models.Restore
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/restores/%s", backup.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
&restores,
)
assert.NotNil(t, restores)
}
func Test_RestoreBackup_WhenUserIsWorkspaceMember_RestoreInitiated(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
assert.Contains(t, string(testResp.Body), "restore started successfully")
}
func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
testResp := test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+nonMember.Token,
request,
http.StatusBadRequest,
)
assert.Contains(t, string(testResp.Body), "insufficient permissions")
}
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
router := createTestRouter()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
database, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
request := RestoreBackupRequest{
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
},
}
test_utils.MakePostRequest(
t,
router,
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
"Bearer "+owner.Token,
request,
http.StatusOK,
)
time.Sleep(100 * time.Millisecond)
auditLogService := audit_logs.GetAuditLogService()
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
workspace.ID,
&audit_logs.GetAuditLogsRequest{
Limit: 100,
Offset: 0,
},
)
assert.NoError(t, err)
found := false
for _, log := range auditLogs.AuditLogs {
if strings.Contains(log.Message, "Database restored from backup") &&
strings.Contains(log.Message, database.Name) {
found = true
break
}
}
assert.True(t, found, "Audit log for restore not found")
}
func createTestDatabaseWithBackupForRestore(
workspace *workspaces_models.Workspace,
owner *users_dto.SignInResponseDTO,
router *gin.Engine,
) (*databases.Database, *backups.Backup) {
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
storage := createTestStorage(workspace.ID)
configService := backups_config.GetBackupConfigService()
config, err := configService.GetBackupConfigByDbId(database.ID)
if err != nil {
panic(err)
}
config.IsBackupsEnabled = true
config.StorageID = &storage.ID
config.Storage = storage
_, err = configService.SaveBackupConfig(config)
if err != nil {
panic(err)
}
backup := createTestBackup(database, owner)
return database, backup
}
func createTestDatabase(
name string,
workspaceID uuid.UUID,
token string,
router *gin.Engine,
) *databases.Database {
testDbName := "test_db"
request := databases.Database{
WorkspaceID: &workspaceID,
Name: name,
Type: databases.DatabaseTypePostgres,
Postgresql: &postgresql.PostgresqlDatabase{
Version: tools.PostgresqlVersion16,
Host: "localhost",
Port: 5432,
Username: "postgres",
Password: "postgres",
Database: &testDbName,
},
}
w := workspaces_testing.MakeAPIRequest(
router,
"POST",
"/api/v1/databases/create",
"Bearer "+token,
request,
)
if w.Code != http.StatusCreated {
panic(
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
)
}
var database databases.Database
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
panic(err)
}
return &database
}
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
storage := &storages.Storage{
WorkspaceID: workspaceID,
Type: storages.StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
}
repo := &storages.StorageRepository{}
storage, err := repo.Save(storage)
if err != nil {
panic(err)
}
return storage
}
func createTestBackup(
database *databases.Database,
owner *users_dto.SignInResponseDTO,
) *backups.Backup {
userService := users_services.GetUserService()
user, err := userService.GetUserFromToken(owner.Token)
if err != nil {
panic(err)
}
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
if err != nil || len(storages) == 0 {
panic("No storage found for workspace")
}
backup := &backups.Backup{
ID: uuid.New(),
DatabaseID: database.ID,
StorageID: storages[0].ID,
Status: backups.BackupStatusCompleted,
BackupSizeMb: 10.5,
BackupDurationMs: 1000,
CreatedAt: time.Now().UTC(),
}
repo := &backups.BackupRepository{}
if err := repo.Save(backup); err != nil {
panic(err)
}
dummyContent := []byte("dummy backup content for testing")
reader := strings.NewReader(string(dummyContent))
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
if err := storages[0].SaveFile(logger, backup.ID, reader); err != nil {
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
}
return backup
}

View File

@@ -1,12 +1,13 @@
package restores
import (
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
"postgresus-backend/internal/features/restores/usecases"
"postgresus-backend/internal/features/storages"
"postgresus-backend/internal/features/users"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/logger"
)
@@ -19,10 +20,11 @@ var restoreService = &RestoreService{
usecases.GetRestoreBackupUsecase(),
databases.GetDatabaseService(),
logger.GetLogger(),
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
}
var restoreController = &RestoreController{
restoreService,
users.GetUserService(),
}
var restoreBackgroundService = &RestoreBackgroundService{

View File

@@ -62,8 +62,6 @@ func (r *RestoreRepository) FindByStatus(status enums.RestoreStatus) ([]*models.
if err := storage.
GetDb().
Preload("Backup.Storage").
Preload("Backup.Database").
Preload("Backup").
Preload("Postgresql").
Where("status = ?", status).

View File

@@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log/slog"
audit_logs "postgresus-backend/internal/features/audit_logs"
"postgresus-backend/internal/features/backups/backups"
backups_config "postgresus-backend/internal/features/backups/config"
"postgresus-backend/internal/features/databases"
@@ -12,6 +13,7 @@ import (
"postgresus-backend/internal/features/restores/usecases"
"postgresus-backend/internal/features/storages"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"postgresus-backend/internal/util/tools"
"time"
@@ -26,6 +28,8 @@ type RestoreService struct {
restoreBackupUsecase *usecases.RestoreBackupUsecase
databaseService *databases.DatabaseService
logger *slog.Logger
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
}
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
@@ -58,8 +62,24 @@ func (s *RestoreService) GetRestores(
return nil, err
}
if backup.Database.UserID != user.ID {
return nil, errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return nil, err
}
if database.WorkspaceID == nil {
return nil, errors.New("cannot get restores for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return nil, err
}
if !canAccess {
return nil, errors.New("insufficient permissions to access restores for this backup")
}
return s.restoreRepository.FindByBackupID(backupID)
@@ -75,8 +95,24 @@ func (s *RestoreService) RestoreBackupWithAuth(
return err
}
if backup.Database.UserID != user.ID {
return errors.New("user does not have access to this backup")
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.WorkspaceID == nil {
return errors.New("cannot restore backup for database without workspace")
}
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
*database.WorkspaceID,
user,
)
if err != nil {
return err
}
if !canAccess {
return errors.New("insufficient permissions to restore this backup")
}
backupDatabase, err := s.databaseService.GetDatabase(user, backup.DatabaseID)
@@ -105,6 +141,16 @@ func (s *RestoreService) RestoreBackupWithAuth(
}
}()
s.auditLogService.WriteAuditLog(
fmt.Sprintf(
"Database restored from backup %s for database: %s",
backupID.String(),
database.Name,
),
&user.ID,
database.WorkspaceID,
)
return nil
}
@@ -116,7 +162,12 @@ func (s *RestoreService) RestoreBackup(
return errors.New("backup is not completed")
}
if backup.Database.Type == databases.DatabaseTypePostgres {
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
if err != nil {
return err
}
if database.Type == databases.DatabaseTypePostgres {
if requestDTO.PostgresqlDatabase == nil {
return errors.New("postgresql database is required")
}
@@ -157,7 +208,7 @@ func (s *RestoreService) RestoreBackup(
}
backupConfig, err := s.backupConfigService.GetBackupConfigByDbId(
backup.Database.ID,
database.ID,
)
if err != nil {
return err
@@ -168,6 +219,7 @@ func (s *RestoreService) RestoreBackup(
err = s.restoreBackupUsecase.Execute(
backupConfig,
restore,
database,
backup,
storage,
)

View File

@@ -31,12 +31,13 @@ type RestorePostgresqlBackupUsecase struct {
}
func (uc *RestorePostgresqlBackupUsecase) Execute(
database *databases.Database,
backupConfig *backups_config.BackupConfig,
restore models.Restore,
backup *backups.Backup,
storage *storages.Storage,
) error {
if backup.Database.Type != databases.DatabaseTypePostgres {
if database.Type != databases.DatabaseTypePostgres {
return errors.New("database type not supported")
}
@@ -76,6 +77,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
}
return uc.restoreFromStorage(
database,
tools.GetPostgresqlExecutable(
pg.Version,
"pg_restore",
@@ -92,6 +94,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
// restoreFromStorage restores backup data from storage using pg_restore
func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
database *databases.Database,
pgBin string,
args []string,
password string,
@@ -164,7 +167,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
// Add the temporary backup file as the last argument to pg_restore
args = append(args, tempBackupFile)
return uc.executePgRestore(ctx, pgBin, args, pgpassFile, pgConfig, backup)
return uc.executePgRestore(ctx, database, pgBin, args, pgpassFile, pgConfig)
}
// downloadBackupToTempFile downloads backup data from storage to a temporary file
@@ -240,11 +243,11 @@ func (uc *RestorePostgresqlBackupUsecase) downloadBackupToTempFile(
// executePgRestore executes the pg_restore command with proper environment setup
func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
ctx context.Context,
database *databases.Database,
pgBin string,
args []string,
pgpassFile string,
pgConfig *pgtypes.PostgresqlDatabase,
backup *backups.Backup,
) error {
cmd := exec.CommandContext(ctx, pgBin, args...)
uc.logger.Info("Executing PostgreSQL restore command", "command", cmd.String())
@@ -293,7 +296,7 @@ func (uc *RestorePostgresqlBackupUsecase) executePgRestore(
return fmt.Errorf("restore cancelled due to shutdown")
}
return uc.handlePgRestoreError(waitErr, stderrOutput, pgBin, args, backup, pgConfig)
return uc.handlePgRestoreError(database, waitErr, stderrOutput, pgBin, args, pgConfig)
}
return nil
@@ -341,11 +344,11 @@ func (uc *RestorePostgresqlBackupUsecase) setupPgRestoreEnvironment(
// handlePgRestoreError processes and formats pg_restore errors
func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
database *databases.Database,
waitErr error,
stderrOutput []byte,
pgBin string,
args []string,
backup *backups.Backup,
pgConfig *pgtypes.PostgresqlDatabase,
) error {
// Enhanced error handling for PostgreSQL connection and restore issues
@@ -416,8 +419,8 @@ func (uc *RestorePostgresqlBackupUsecase) handlePgRestoreError(
)
} else if containsIgnoreCase(stderrStr, "database") && containsIgnoreCase(stderrStr, "does not exist") {
backupDbName := "unknown"
if backup.Database != nil && backup.Database.Postgresql != nil && backup.Database.Postgresql.Database != nil {
backupDbName = *backup.Database.Postgresql.Database
if database.Postgresql != nil && database.Postgresql.Database != nil {
backupDbName = *database.Postgresql.Database
}
targetDbName := "unknown"

View File

@@ -17,11 +17,13 @@ type RestoreBackupUsecase struct {
func (uc *RestoreBackupUsecase) Execute(
backupConfig *backups_config.BackupConfig,
restore models.Restore,
database *databases.Database,
backup *backups.Backup,
storage *storages.Storage,
) error {
if restore.Backup.Database.Type == databases.DatabaseTypePostgres {
if database.Type == databases.DatabaseTypePostgres {
return uc.restorePostgresqlBackupUsecase.Execute(
database,
backupConfig,
restore,
backup,

View File

@@ -2,15 +2,16 @@ package storages
import (
"net/http"
"postgresus-backend/internal/features/users"
users_middleware "postgresus-backend/internal/features/users/middleware"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type StorageController struct {
storageService *StorageService
userService *users.UserService
storageService *StorageService
workspaceService *workspaces_services.WorkspaceService
}
func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
@@ -29,35 +30,40 @@ func (c *StorageController) RegisterRoutes(router *gin.RouterGroup) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param storage body Storage true "Storage data"
// @Param request body Storage true "Storage data with workspaceId"
// @Success 200 {object} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages [post]
func (c *StorageController) SaveStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var storage Storage
if err := ctx.ShouldBindJSON(&storage); err != nil {
var request Storage
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := storage.Validate(); err != nil {
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
if err := c.storageService.SaveStorage(user, request.WorkspaceID, &request); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := c.storageService.SaveStorage(user, &storage); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, storage)
ctx.JSON(http.StatusOK, request)
}
// GetStorage
@@ -70,11 +76,12 @@ func (c *StorageController) SaveStorage(ctx *gin.Context) {
// @Success 200 {object} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id} [get]
func (c *StorageController) GetStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -86,6 +93,10 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
storage, err := c.storageService.GetStorage(user, id)
if err != nil {
if err.Error() == "insufficient permissions to view storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -95,22 +106,41 @@ func (c *StorageController) GetStorage(ctx *gin.Context) {
// GetStorages
// @Summary Get all storages
// @Description Get all storages for the current user
// @Description Get all storages for a workspace
// @Tags storages
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param workspace_id query string true "Workspace ID"
// @Success 200 {array} Storage
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages [get]
func (c *StorageController) GetStorages(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
storages, err := c.storageService.GetStorages(user)
workspaceIDStr := ctx.Query("workspace_id")
if workspaceIDStr == "" {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
return
}
workspaceID, err := uuid.Parse(workspaceIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
return
}
storages, err := c.storageService.GetStorages(user, workspaceID)
if err != nil {
if err.Error() == "insufficient permissions to view storages in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -128,11 +158,12 @@ func (c *StorageController) GetStorages(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id} [delete]
func (c *StorageController) DeleteStorage(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -143,6 +174,10 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
}
if err := c.storageService.DeleteStorage(user, id); err != nil {
if err.Error() == "insufficient permissions to manage storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -160,11 +195,12 @@ func (c *StorageController) DeleteStorage(ctx *gin.Context) {
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/{id}/test [post]
func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
@@ -175,6 +211,10 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
}
if err := c.storageService.TestStorageConnection(user, id); err != nil {
if err.Error() == "insufficient permissions to test storage in this workspace" {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -189,33 +229,44 @@ func (c *StorageController) TestStorageConnection(ctx *gin.Context) {
// @Accept json
// @Produce json
// @Param Authorization header string true "JWT token"
// @Param storage body Storage true "Storage data"
// @Param request body Storage true "Storage data with workspaceId"
// @Success 200
// @Failure 400
// @Failure 401
// @Failure 403
// @Router /storages/direct-test [post]
func (c *StorageController) TestStorageConnectionDirect(ctx *gin.Context) {
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
user, ok := users_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
var request Storage
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if request.WorkspaceID == uuid.Nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
return
}
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
if err != nil {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
var storage Storage
if err := ctx.ShouldBindJSON(&storage); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// For direct test, associate with the current user
storage.UserID = user.ID
if err := storage.Validate(); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
if !canView {
ctx.JSON(
http.StatusForbidden,
gin.H{"error": "insufficient permissions to test storage in this workspace"},
)
return
}
if err := c.storageService.TestStorageConnectionDirect(&storage); err != nil {
if err := c.storageService.TestStorageConnectionDirect(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}

View File

@@ -1,25 +1,41 @@
package storages
import (
"fmt"
"net/http"
local_storage "postgresus-backend/internal/features/storages/models/local"
"postgresus-backend/internal/features/users"
test_utils "postgresus-backend/internal/util/testing"
"testing"
audit_logs "postgresus-backend/internal/features/audit_logs"
local_storage "postgresus-backend/internal/features/storages/models/local"
s3_storage "postgresus-backend/internal/features/storages/models/s3"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
)
func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
verifyStorageData(t, storage, &savedStorage)
@@ -30,8 +46,8 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/storages/"+savedStorage.ID.String(),
user.Token,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
@@ -41,181 +57,408 @@ func Test_SaveNewStorage_StorageReturnedViaGet(t *testing.T) {
// Verify storage is returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&storages,
)
assert.Contains(t, storages, savedStorage)
RemoveTestStorage(savedStorage.ID)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UpdateExistingStorage_UpdatedStorageReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
// Save initial storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Modify storage name
updatedName := "Updated Storage " + uuid.New().String()
savedStorage.Name = updatedName
// Update storage
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, savedStorage, http.StatusOK, &updatedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
// Verify updated data
assert.Equal(t, updatedName, updatedStorage.Name)
assert.Equal(t, savedStorage.ID, updatedStorage.ID)
// Verify through GET
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/storages/"+updatedStorage.ID.String(),
user.Token,
http.StatusOK,
&retrievedStorage,
)
verifyStorageData(t, &updatedStorage, &retrievedStorage)
// Verify storage is returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
)
assert.Contains(t, storages, updatedStorage)
RemoveTestStorage(updatedStorage.ID)
deleteStorage(t, router, updatedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_DeleteStorage_StorageNotReturnedViaGet(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
// Save initial storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Delete storage
test_utils.MakeDeleteRequest(
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusOK,
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
)
// Try to get deleted storage, should return error
response := test_utils.MakeGetRequest(
t, router, "/api/v1/storages/"+savedStorage.ID.String(), user.Token, http.StatusBadRequest,
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusBadRequest,
)
assert.Contains(t, string(response.Body), "error")
// Verify storage is not returned via GET all storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, http.StatusOK, &storages,
)
assert.NotContains(t, storages, savedStorage)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_TestDirectStorageConnection_ConnectionEstablished(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
response := test_utils.MakePostRequest(
t, router, "/api/v1/storages/direct-test", user.Token, storage, http.StatusOK,
t, router, "/api/v1/storages/direct-test", "Bearer "+owner.Token, *storage, http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_TestExistingStorageConnection_ConnectionEstablished(t *testing.T) {
user := users.GetTestUser()
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(user.UserID)
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t, router, "/api/v1/storages", user.Token, storage, http.StatusOK, &savedStorage,
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Test connection to existing storage
response := test_utils.MakePostRequest(
t,
router,
"/api/v1/storages/"+savedStorage.ID.String()+"/test",
user.Token,
fmt.Sprintf("/api/v1/storages/%s/test", savedStorage.ID.String()),
"Bearer "+owner.Token,
nil,
http.StatusOK,
)
assert.Contains(t, string(response.Body), "successful")
RemoveTestStorage(savedStorage.ID)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CallAllMethodsWithoutAuth_UnauthorizedErrorReturned(t *testing.T) {
func Test_ViewerCanViewStorages_ButCannotModify(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
viewer := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
storage := createNewStorage(uuid.New())
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
viewer,
users_enums.WorkspaceRoleViewer,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Test endpoints without auth
endpoints := []struct {
method string
url string
body interface{}
}{
{"GET", "/api/v1/storages", nil},
{"GET", "/api/v1/storages/" + uuid.New().String(), nil},
{"POST", "/api/v1/storages", storage},
{"DELETE", "/api/v1/storages/" + uuid.New().String(), nil},
{"POST", "/api/v1/storages/" + uuid.New().String() + "/test", nil},
{"POST", "/api/v1/storages/direct-test", storage},
}
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
for _, endpoint := range endpoints {
testUnauthorizedEndpoint(t, router, endpoint.method, endpoint.url, endpoint.body)
}
// Viewer can GET storages
var storages []Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+viewer.Token,
http.StatusOK,
&storages,
)
assert.Len(t, storages, 1)
// Viewer cannot CREATE storage
newStorage := createNewStorage(workspace.ID)
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, *newStorage, http.StatusForbidden,
)
// Viewer cannot UPDATE storage
savedStorage.Name = "Updated by viewer"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+viewer.Token, savedStorage, http.StatusForbidden,
)
// Viewer cannot DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+viewer.Token,
http.StatusForbidden,
)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func testUnauthorizedEndpoint(
t *testing.T,
router *gin.Engine,
method, url string,
body interface{},
) {
test_utils.MakeRequest(t, router, test_utils.RequestOptions{
Method: method,
URL: url,
Body: body,
ExpectedStatus: http.StatusUnauthorized,
})
func Test_MemberCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
member,
users_enums.WorkspaceRoleMember,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Member can CREATE storage
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
*storage,
http.StatusOK,
&savedStorage,
)
assert.NotEmpty(t, savedStorage.ID)
// Member can UPDATE storage
savedStorage.Name = "Updated by member"
var updatedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+member.Token,
savedStorage,
http.StatusOK,
&updatedStorage,
)
assert.Equal(t, "Updated by member", updatedStorage.Name)
// Member can DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+member.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_AdminCanManageStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
admin := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
workspaces_testing.AddMemberToWorkspace(
workspace,
admin,
users_enums.WorkspaceRoleAdmin,
owner.Token,
router,
)
storage := createNewStorage(workspace.ID)
// Admin can CREATE, UPDATE, DELETE
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+admin.Token,
*storage,
http.StatusOK,
&savedStorage,
)
savedStorage.Name = "Updated by admin"
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+admin.Token, savedStorage, http.StatusOK,
)
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+admin.Token,
http.StatusOK,
)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_UserNotInWorkspace_CannotAccessStorages(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
outsider := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
storage := createNewStorage(workspace.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*storage,
http.StatusOK,
&savedStorage,
)
// Outsider cannot GET storages
test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/storages?workspace_id=%s", workspace.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
// Outsider cannot CREATE storage
test_utils.MakePostRequest(
t, router, "/api/v1/storages", "Bearer "+outsider.Token, *storage, http.StatusForbidden,
)
// Outsider cannot UPDATE storage
test_utils.MakePostRequest(
t,
router,
"/api/v1/storages",
"Bearer "+outsider.Token,
savedStorage,
http.StatusForbidden,
)
// Outsider cannot DELETE storage
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+outsider.Token,
http.StatusForbidden,
)
deleteStorage(t, router, savedStorage.ID, workspace.ID, owner.Token)
workspaces_testing.RemoveTestWorkspace(workspace, router)
}
func Test_CrossWorkspaceSecurity_CannotAccessStorageFromAnotherWorkspace(t *testing.T) {
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace1 := workspaces_testing.CreateTestWorkspace("Workspace 1", owner1, router)
workspace2 := workspaces_testing.CreateTestWorkspace("Workspace 2", owner2, router)
storage1 := createNewStorage(workspace1.ID)
var savedStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner1.Token,
*storage1,
http.StatusOK,
&savedStorage,
)
// Try to access workspace1's storage with owner2 from workspace2
response := test_utils.MakeGetRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", savedStorage.ID.String()),
"Bearer "+owner2.Token,
http.StatusForbidden,
)
assert.Contains(t, string(response.Body), "insufficient permissions")
deleteStorage(t, router, savedStorage.ID, workspace1.ID, owner1.Token)
workspaces_testing.RemoveTestWorkspace(workspace1, router)
workspaces_testing.RemoveTestWorkspace(workspace2, router)
}
func createRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
controller := GetStorageController()
v1 := router.Group("/api/v1")
controller.RegisterRoutes(v1)
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
GetStorageController().RegisterRoutes(routerGroup)
workspaces_controllers.GetWorkspaceController().RegisterRoutes(routerGroup)
workspaces_controllers.GetMembershipController().RegisterRoutes(routerGroup)
}
audit_logs.SetupDependencies()
return router
}
func createNewStorage(userID uuid.UUID) *Storage {
func createNewStorage(workspaceID uuid.UUID) *Storage {
return &Storage{
UserID: userID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},
@@ -225,5 +468,175 @@ func createNewStorage(userID uuid.UUID) *Storage {
func verifyStorageData(t *testing.T, expected *Storage, actual *Storage) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Type, actual.Type)
assert.Equal(t, expected.UserID, actual.UserID)
assert.Equal(t, expected.WorkspaceID, actual.WorkspaceID)
}
func deleteStorage(
t *testing.T,
router *gin.Engine,
storageID, workspaceID uuid.UUID,
token string,
) {
test_utils.MakeDeleteRequest(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", storageID.String()),
"Bearer "+token,
http.StatusOK,
)
}
func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
testCases := []struct {
name string
storageType StorageType
createStorage func(workspaceID uuid.UUID) *Storage
updateStorage func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage
verifySensitiveData func(t *testing.T, storage *Storage)
verifyHiddenData func(t *testing.T, storage *Storage)
}{
{
name: "S3 Storage",
storageType: StorageTypeS3,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Test S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "test-bucket",
S3Region: "us-east-1",
S3AccessKey: "original-access-key",
S3SecretKey: "original-secret-key",
S3Endpoint: "https://s3.amazonaws.com",
},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeS3,
Name: "Updated S3 Storage",
S3Storage: &s3_storage.S3Storage{
S3Bucket: "updated-bucket",
S3Region: "us-west-2",
S3AccessKey: "",
S3SecretKey: "",
S3Endpoint: "https://s3.us-west-2.amazonaws.com",
},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "original-access-key", storage.S3Storage.S3AccessKey)
assert.Equal(t, "original-secret-key", storage.S3Storage.S3SecretKey)
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
assert.Equal(t, "", storage.S3Storage.S3AccessKey)
assert.Equal(t, "", storage.S3Storage.S3SecretKey)
},
},
{
name: "Local Storage",
storageType: StorageTypeLocal,
createStorage: func(workspaceID uuid.UUID) *Storage {
return &Storage{
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
return &Storage{
ID: storageID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Updated Local Storage",
LocalStorage: &local_storage.LocalStorage{},
}
},
verifySensitiveData: func(t *testing.T, storage *Storage) {
},
verifyHiddenData: func(t *testing.T, storage *Storage) {
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
router := createRouter()
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
// Phase 1: Create storage with sensitive data
initialStorage := tc.createStorage(workspace.ID)
var createdStorage Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*initialStorage,
http.StatusOK,
&createdStorage,
)
assert.NotEmpty(t, createdStorage.ID)
assert.Equal(t, initialStorage.Name, createdStorage.Name)
// Phase 2: Read via service - sensitive data should be hidden
var retrievedStorage Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&retrievedStorage,
)
tc.verifyHiddenData(t, &retrievedStorage)
assert.Equal(t, initialStorage.Name, retrievedStorage.Name)
// Phase 3: Update with non-sensitive changes only (sensitive fields empty)
updatedStorage := tc.updateStorage(workspace.ID, createdStorage.ID)
var updateResponse Storage
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/storages",
"Bearer "+owner.Token,
*updatedStorage,
http.StatusOK,
&updateResponse,
)
// Verify non-sensitive fields were updated
assert.Equal(t, updatedStorage.Name, updateResponse.Name)
// Phase 4: Retrieve directly from repository to verify sensitive data preservation
repository := &StorageRepository{}
storageFromDB, err := repository.FindByID(createdStorage.ID)
assert.NoError(t, err)
// Verify original sensitive data is still present in DB
tc.verifySensitiveData(t, storageFromDB)
// Verify non-sensitive fields were updated in DB
assert.Equal(t, updatedStorage.Name, storageFromDB.Name)
// Additional verification: Check via GET that data is still hidden
var finalRetrieved Storage
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
fmt.Sprintf("/api/v1/storages/%s", createdStorage.ID.String()),
"Bearer "+owner.Token,
http.StatusOK,
&finalRetrieved,
)
tc.verifyHiddenData(t, &finalRetrieved)
})
}
}

View File

@@ -1,16 +1,19 @@
package storages
import (
"postgresus-backend/internal/features/users"
audit_logs "postgresus-backend/internal/features/audit_logs"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
)
var storageRepository = &StorageRepository{}
var storageService = &StorageService{
storageRepository,
workspaces_services.GetWorkspaceService(),
audit_logs.GetAuditLogService(),
}
var storageController = &StorageController{
storageService,
users.GetUserService(),
workspaces_services.GetWorkspaceService(),
}
func GetStorageService() *StorageService {
@@ -20,3 +23,7 @@ func GetStorageService() *StorageService {
func GetStorageController() *StorageController {
return storageController
}
func SetupDependencies() {
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(storageService)
}

View File

@@ -17,4 +17,6 @@ type StorageFileSaver interface {
Validate() error
TestConnection() error
HideSensitiveData()
}

View File

@@ -14,7 +14,7 @@ import (
type Storage struct {
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null;type:uuid;index"`
WorkspaceID uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;not null;type:uuid;index"`
Type StorageType `json:"type" gorm:"column:type;not null;type:text"`
Name string `json:"name" gorm:"column:name;not null;type:text"`
LastSaveError *string `json:"lastSaveError" gorm:"column:last_save_error;type:text"`
@@ -63,6 +63,34 @@ func (s *Storage) TestConnection() error {
return s.getSpecificStorage().TestConnection()
}
func (s *Storage) HideSensitiveData() {
s.getSpecificStorage().HideSensitiveData()
}
func (s *Storage) Update(incoming *Storage) {
s.Name = incoming.Name
s.Type = incoming.Type
switch s.Type {
case StorageTypeLocal:
if s.LocalStorage != nil && incoming.LocalStorage != nil {
s.LocalStorage.Update(incoming.LocalStorage)
}
case StorageTypeS3:
if s.S3Storage != nil && incoming.S3Storage != nil {
s.S3Storage.Update(incoming.S3Storage)
}
case StorageTypeGoogleDrive:
if s.GoogleDriveStorage != nil && incoming.GoogleDriveStorage != nil {
s.GoogleDriveStorage.Update(incoming.GoogleDriveStorage)
}
case StorageTypeNAS:
if s.NASStorage != nil && incoming.NASStorage != nil {
s.NASStorage.Update(incoming.NASStorage)
}
}
}
func (s *Storage) getSpecificStorage() StorageFileSaver {
switch s.Type {
case StorageTypeLocal:

View File

@@ -191,6 +191,23 @@ func (s *GoogleDriveStorage) TestConnection() error {
})
}
func (s *GoogleDriveStorage) HideSensitiveData() {
s.ClientSecret = ""
s.TokenJSON = ""
}
func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
s.ClientID = incoming.ClientID
if incoming.ClientSecret != "" {
s.ClientSecret = incoming.ClientSecret
}
if incoming.TokenJSON != "" {
s.TokenJSON = incoming.TokenJSON
}
}
// withRetryOnAuth executes the provided function with retry logic for authentication errors
func (s *GoogleDriveStorage) withRetryOnAuth(fn func(*drive.Service) error) error {
driveService, err := s.getDriveService()

View File

@@ -156,3 +156,9 @@ func (l *LocalStorage) TestConnection() error {
return nil
}
func (l *LocalStorage) HideSensitiveData() {
}
func (l *LocalStorage) Update(incoming *LocalStorage) {
}

View File

@@ -251,6 +251,24 @@ func (n *NASStorage) TestConnection() error {
return nil
}
func (n *NASStorage) HideSensitiveData() {
n.Password = ""
}
func (n *NASStorage) Update(incoming *NASStorage) {
n.Host = incoming.Host
n.Port = incoming.Port
n.Share = incoming.Share
n.Username = incoming.Username
n.UseSSL = incoming.UseSSL
n.Domain = incoming.Domain
n.Path = incoming.Path
if incoming.Password != "" {
n.Password = incoming.Password
}
}
func (n *NASStorage) createSession() (*smb2.Session, error) {
// Create connection with timeout
conn, err := n.createConnection()

View File

@@ -180,6 +180,25 @@ func (s *S3Storage) TestConnection() error {
return nil
}
func (s *S3Storage) HideSensitiveData() {
s.S3AccessKey = ""
s.S3SecretKey = ""
}
func (s *S3Storage) Update(incoming *S3Storage) {
s.S3Bucket = incoming.S3Bucket
s.S3Region = incoming.S3Region
s.S3Endpoint = incoming.S3Endpoint
if incoming.S3AccessKey != "" {
s.S3AccessKey = incoming.S3AccessKey
}
if incoming.S3SecretKey != "" {
s.S3SecretKey = incoming.S3SecretKey
}
}
func (s *S3Storage) getClient() (*minio.Client, error) {
endpoint := s.S3Endpoint
useSSL := true

View File

@@ -104,7 +104,7 @@ func (r *StorageRepository) FindByID(id uuid.UUID) (*Storage, error) {
return &s, nil
}
func (r *StorageRepository) FindByUserID(userID uuid.UUID) ([]*Storage, error) {
func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage, error) {
var storages []*Storage
if err := db.
@@ -113,7 +113,7 @@ func (r *StorageRepository) FindByUserID(userID uuid.UUID) ([]*Storage, error) {
Preload("S3Storage").
Preload("GoogleDriveStorage").
Preload("NASStorage").
Where("user_id = ?", userID).
Where("workspace_id = ?", workspaceID).
Order("name ASC").
Find(&storages).Error; err != nil {
return nil, err

View File

@@ -2,37 +2,79 @@ package storages
import (
"errors"
"fmt"
audit_logs "postgresus-backend/internal/features/audit_logs"
users_models "postgresus-backend/internal/features/users/models"
workspaces_services "postgresus-backend/internal/features/workspaces/services"
"github.com/google/uuid"
)
type StorageService struct {
storageRepository *StorageRepository
workspaceService *workspaces_services.WorkspaceService
auditLogService *audit_logs.AuditLogService
}
func (s *StorageService) SaveStorage(
user *users_models.User,
workspaceID uuid.UUID,
storage *Storage,
) error {
if storage.ID != uuid.Nil {
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage storage in this workspace")
}
isUpdate := storage.ID != uuid.Nil
if isUpdate {
existingStorage, err := s.storageRepository.FindByID(storage.ID)
if err != nil {
return err
}
if existingStorage.UserID != user.ID {
return errors.New("you have not access to this storage")
if existingStorage.WorkspaceID != workspaceID {
return errors.New("storage does not belong to this workspace")
}
storage.UserID = existingStorage.UserID
} else {
storage.UserID = user.ID
}
existingStorage.Update(storage)
_, err := s.storageRepository.Save(storage)
if err != nil {
return err
if err := existingStorage.Validate(); err != nil {
return err
}
_, err = s.storageRepository.Save(existingStorage)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage updated: %s", existingStorage.Name),
&user.ID,
&workspaceID,
)
} else {
storage.WorkspaceID = workspaceID
if err := storage.Validate(); err != nil {
return err
}
_, err = s.storageRepository.Save(storage)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage created: %s", storage.Name),
&user.ID,
&workspaceID,
)
}
return nil
@@ -47,11 +89,26 @@ func (s *StorageService) DeleteStorage(
return err
}
if storage.UserID != user.ID {
return errors.New("you have not access to this storage")
canManage, err := s.workspaceService.CanUserManageDBs(storage.WorkspaceID, user)
if err != nil {
return err
}
if !canManage {
return errors.New("insufficient permissions to manage storage in this workspace")
}
return s.storageRepository.Delete(storage)
err = s.storageRepository.Delete(storage)
if err != nil {
return err
}
s.auditLogService.WriteAuditLog(
fmt.Sprintf("Storage deleted: %s", storage.Name),
&user.ID,
&storage.WorkspaceID,
)
return nil
}
func (s *StorageService) GetStorage(
@@ -63,17 +120,41 @@ func (s *StorageService) GetStorage(
return nil, err
}
if storage.UserID != user.ID {
return nil, errors.New("you have not access to this storage")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view storage in this workspace")
}
storage.HideSensitiveData()
return storage, nil
}
func (s *StorageService) GetStorages(
user *users_models.User,
workspaceID uuid.UUID,
) ([]*Storage, error) {
return s.storageRepository.FindByUserID(user.ID)
canView, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
if err != nil {
return nil, err
}
if !canView {
return nil, errors.New("insufficient permissions to view storages in this workspace")
}
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return nil, err
}
for _, storage := range storages {
storage.HideSensitiveData()
}
return storages, nil
}
func (s *StorageService) TestStorageConnection(
@@ -85,8 +166,12 @@ func (s *StorageService) TestStorageConnection(
return err
}
if storage.UserID != user.ID {
return errors.New("you have not access to this storage")
canView, _, err := s.workspaceService.CanUserAccessWorkspace(storage.WorkspaceID, user)
if err != nil {
return err
}
if !canView {
return errors.New("insufficient permissions to test storage in this workspace")
}
err = storage.TestConnection()
@@ -108,7 +193,30 @@ func (s *StorageService) TestStorageConnection(
func (s *StorageService) TestStorageConnectionDirect(
storage *Storage,
) error {
return storage.TestConnection()
var usingStorage *Storage
if storage.ID != uuid.Nil {
existingStorage, err := s.storageRepository.FindByID(storage.ID)
if err != nil {
return err
}
if existingStorage.WorkspaceID != storage.WorkspaceID {
return errors.New("storage does not belong to this workspace")
}
existingStorage.Update(storage)
if err := existingStorage.Validate(); err != nil {
return err
}
usingStorage = existingStorage
} else {
usingStorage = storage
}
return usingStorage.TestConnection()
}
func (s *StorageService) GetStorageByID(
@@ -116,3 +224,18 @@ func (s *StorageService) GetStorageByID(
) (*Storage, error) {
return s.storageRepository.FindByID(id)
}
func (s *StorageService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
storages, err := s.storageRepository.FindByWorkspaceID(workspaceID)
if err != nil {
return fmt.Errorf("failed to get storages for workspace deletion: %w", err)
}
for _, storage := range storages {
if err := s.storageRepository.Delete(storage); err != nil {
return fmt.Errorf("failed to delete storage %s: %w", storage.ID, err)
}
}
return nil
}

View File

@@ -6,9 +6,9 @@ import (
"github.com/google/uuid"
)
func CreateTestStorage(userID uuid.UUID) *Storage {
func CreateTestStorage(workspaceID uuid.UUID) *Storage {
storage := &Storage{
UserID: userID,
WorkspaceID: workspaceID,
Type: StorageTypeLocal,
Name: "Test Storage " + uuid.New().String(),
LocalStorage: &local_storage.LocalStorage{},

View File

@@ -1,6 +1,7 @@
package tests
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -68,6 +69,7 @@ func Test_BackupAndRestorePostgresql_RestoreIsSuccesful(t *testing.T) {
version string
port string
}{
{"PostgreSQL 12", "12", env.TestPostgres12Port},
{"PostgreSQL 13", "13", env.TestPostgres13Port},
{"PostgreSQL 14", "14", env.TestPostgres14Port},
{"PostgreSQL 15", "15", env.TestPostgres15Port},
@@ -129,7 +131,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
}
storage := &storages.Storage{
UserID: uuid.New(),
WorkspaceID: uuid.New(),
Type: storages.StorageTypeLocal,
Name: "Test Storage",
LocalStorage: &local_storage.LocalStorage{},
@@ -138,6 +140,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
// Make backup
progressTracker := func(completedMBs float64) {}
err = usecases_postgresql_backup.GetCreatePostgresqlBackupUsecase().Execute(
context.Background(),
backupID,
backupConfig,
backupDb,
@@ -168,8 +171,6 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
StorageID: storage.ID,
Status: backups.BackupStatusCompleted,
CreatedAt: time.Now().UTC(),
Storage: storage,
Database: backupDb,
}
restoreID := uuid.New()
@@ -189,7 +190,7 @@ func testBackupRestoreForVersion(t *testing.T, pgVersion string, port string) {
// Restore the backup
restoreBackupUC := usecases_postgresql_restore.GetRestorePostgresqlBackupUsecase()
err = restoreBackupUC.Execute(backupConfig, restore, completedBackup, storage)
err = restoreBackupUC.Execute(backupDb, backupConfig, restore, completedBackup, storage)
assert.NoError(t, err)
// Verify restored table exists

View File

@@ -1,98 +0,0 @@
package users
import (
"net/http"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
type UserController struct {
userService *UserService
signinLimiter *rate.Limiter
}
func (c *UserController) RegisterRoutes(router *gin.RouterGroup) {
router.POST("/users/signup", c.SignUp)
router.POST("/users/signin", c.SignIn)
router.GET("/users/is-any-user-exist", c.IsAnyUserExist)
}
// SignUp
// @Summary Register a new user
// @Description Register a new user with email and password
// @Tags users
// @Accept json
// @Produce json
// @Param request body SignUpRequest true "User signup data"
// @Success 200
// @Failure 400
// @Router /users/signup [post]
func (c *UserController) SignUp(ctx *gin.Context) {
var request SignUpRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
err := c.userService.SignUp(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User created successfully"})
}
// SignIn
// @Summary Authenticate a user
// @Description Authenticate a user with email and password
// @Tags users
// @Accept json
// @Produce json
// @Param request body SignInRequest true "User signin data"
// @Success 200 {object} SignInResponse
// @Failure 400
// @Failure 429 {object} map[string]string "Rate limit exceeded"
// @Router /users/signin [post]
func (c *UserController) SignIn(ctx *gin.Context) {
// We use rate limiter to prevent brute force attacks
if !c.signinLimiter.Allow() {
ctx.JSON(
http.StatusTooManyRequests,
gin.H{"error": "Rate limit exceeded. Please try again later."},
)
return
}
var request SignInRequest
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
response, err := c.userService.SignIn(&request)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, response)
}
// IsAnyUserExist
// @Summary Check if any user exists
// @Description Check if any user exists in the system
// @Tags users
// @Produce json
// @Success 200 {object} map[string]bool
// @Router /users/is-any-user-exist [get]
func (c *UserController) IsAnyUserExist(ctx *gin.Context) {
isExist, err := c.userService.IsAnyUserExist()
if err != nil {
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"isExist": isExist})
}

View File

@@ -0,0 +1,32 @@
package users_controllers
import (
users_services "postgresus-backend/internal/features/users/services"
"golang.org/x/time/rate"
)
var userController = &UserController{
users_services.GetUserService(),
rate.NewLimiter(rate.Limit(3), 3), // 3 rps with 3 burst
}
var settingsController = &SettingsController{
users_services.GetSettingsService(),
}
var managementController = &ManagementController{
users_services.GetManagementService(),
}
func GetUserController() *UserController {
return userController
}
func GetSettingsController() *SettingsController {
return settingsController
}
func GetManagementController() *ManagementController {
return managementController
}

View File

@@ -0,0 +1,263 @@
package users_controllers
import (
"net/http"
"testing"
users_dto "postgresus-backend/internal/features/users/dto"
users_enums "postgresus-backend/internal/features/users/enums"
users_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
users_testing "postgresus-backend/internal/features/users/testing"
test_utils "postgresus-backend/internal/util/testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"golang.org/x/time/rate"
)
func Test_AdminLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := createE2ETestRouter()
users_testing.RecreateInitialAdmin()
// 1. Set initial admin password
adminPasswordRequest := users_dto.SetAdminPasswordRequestDTO{
Password: "adminpassword123",
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/admin/set-password",
"",
adminPasswordRequest,
http.StatusOK,
)
// 2. Admin signs in
adminSigninRequest := users_dto.SignInRequestDTO{
Email: "admin",
Password: "adminpassword123",
}
var adminSigninResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
adminSigninRequest,
http.StatusOK,
&adminSigninResponse,
)
// 3. Admin invites a user
workspaceID := uuid.New()
workspaceRole := users_enums.WorkspaceRoleMember
invitedUserEmail := "invited" + uuid.New().String() + "@example.com"
inviteRequest := users_dto.InviteUserRequestDTO{
Email: invitedUserEmail,
IntendedWorkspaceID: &workspaceID,
IntendedWorkspaceRole: &workspaceRole,
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/invite",
"Bearer "+adminSigninResponse.Token,
inviteRequest,
http.StatusOK,
)
// 4. Invited user signs up
userSignupRequest := users_dto.SignUpRequestDTO{
Email: invitedUserEmail,
Password: "userpassword123",
Name: "Invited User",
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signup",
"",
userSignupRequest,
http.StatusOK,
)
// 5. User signs in
userSigninRequest := users_dto.SignInRequestDTO{
Email: invitedUserEmail,
Password: "userpassword123",
}
var userSigninResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
userSigninRequest,
http.StatusOK,
&userSigninResponse,
)
// 6. Admin lists users and sees new user
var listUsersResponse users_dto.ListUsersResponseDTO
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/users",
"Bearer "+adminSigninResponse.Token,
http.StatusOK,
&listUsersResponse,
)
assert.GreaterOrEqual(t, len(listUsersResponse.Users), 2) // Admin + new user
}
func Test_UserLifecycleE2E_CompletesSuccessfully(t *testing.T) {
router := createE2ETestRouter()
users_testing.ResetSettingsToDefaults()
// 1. User registers
userEmail := "testuser" + uuid.New().String() + "@example.com"
userSignupRequest := users_dto.SignUpRequestDTO{
Email: userEmail,
Password: "userpassword123",
Name: "Test User",
}
test_utils.MakePostRequest(
t,
router,
"/api/v1/users/signup",
"",
userSignupRequest,
http.StatusOK,
)
// 2. User signs in
userSigninRequest := users_dto.SignInRequestDTO{
Email: userEmail,
Password: "userpassword123",
}
var signinResponse users_dto.SignInResponseDTO
test_utils.MakePostRequestAndUnmarshal(
t,
router,
"/api/v1/users/signin",
"",
userSigninRequest,
http.StatusOK,
&signinResponse,
)
assert.NotEmpty(t, signinResponse.Token)
assert.NotEqual(t, uuid.Nil, signinResponse.UserID)
// 3. User gets own profile
var profileResponse users_dto.UserProfileResponseDTO
test_utils.MakeGetRequestAndUnmarshal(
t,
router,
"/api/v1/users/"+signinResponse.UserID.String(),
"Bearer "+signinResponse.Token,
http.StatusOK,
&profileResponse,
)
assert.Equal(t, signinResponse.UserID, profileResponse.ID)
assert.Equal(t, userEmail, profileResponse.Email)
assert.Equal(t, users_enums.UserRoleMember, profileResponse.Role)
assert.True(t, profileResponse.IsActive)
}
// Test router creation helpers
func createUserTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
// Register public routes
GetUserController().RegisterRoutes(v1)
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
GetUserController().SetSignInLimiter(rate.NewLimiter(rate.Limit(100), 100))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
return router
}
func createSettingsTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetSettingsController().RegisterRoutes(protected.(*gin.RouterGroup))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetSettingsService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetManagementService().SetAuditLogWriter(&AuditLogWriterStub{})
return router
}
func createManagementTestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetManagementController().RegisterRoutes(protected.(*gin.RouterGroup))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetSettingsService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetManagementService().SetAuditLogWriter(&AuditLogWriterStub{})
return router
}
func createE2ETestRouter() *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
v1 := router.Group("/api/v1")
// Register all routes
GetUserController().RegisterRoutes(v1)
// Register protected routes with auth middleware
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
GetUserController().RegisterProtectedRoutes(protected.(*gin.RouterGroup))
GetSettingsController().RegisterRoutes(protected.(*gin.RouterGroup))
GetManagementController().RegisterRoutes(protected.(*gin.RouterGroup))
// Setup audit log service
users_services.GetUserService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetSettingsService().SetAuditLogWriter(&AuditLogWriterStub{})
users_services.GetManagementService().SetAuditLogWriter(&AuditLogWriterStub{})
return router
}
type AuditLogWriterStub struct{}
func (a *AuditLogWriterStub) WriteAuditLog(
message string,
userID *uuid.UUID,
workspaceID *uuid.UUID,
) {
// do nothing
}

View File

@@ -0,0 +1,260 @@
package users_controllers
import (
"fmt"
"net/http"
user_dto "postgresus-backend/internal/features/users/dto"
user_enums "postgresus-backend/internal/features/users/enums"
user_middleware "postgresus-backend/internal/features/users/middleware"
users_services "postgresus-backend/internal/features/users/services"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type ManagementController struct {
managementService *users_services.UserManagementService
}
func (c *ManagementController) RegisterRoutes(router *gin.RouterGroup) {
router.GET("/users", user_middleware.RequireRole(user_enums.UserRoleAdmin), c.GetUsers)
router.GET("/users/:id", c.GetUserProfile)
router.POST(
"/users/:id/deactivate",
user_middleware.RequireRole(user_enums.UserRoleAdmin),
c.DeactivateUser,
)
router.POST(
"/users/:id/activate",
user_middleware.RequireRole(user_enums.UserRoleAdmin),
c.ActivateUser,
)
router.PUT(
"/users/:id/role",
user_middleware.RequireRole(user_enums.UserRoleAdmin),
c.ChangeUserRole,
)
}
// ListUsers
// @Summary List users
// @Description Get list of users (admin only)
// @Tags user-management
// @Produce json
// @Security BearerAuth
// @Param limit query int false "Number of items per page" default(20)
// @Param offset query int false "Page offset" default(0)
// @Param beforeDate query string false "Filter users created before this date (RFC3339 format)" format(date-time)
// @Param query query string false "Search by email or name (case-insensitive)"
// @Success 200 {object} users_dto.ListUsersResponseDTO
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 403 {object} map[string]string "Forbidden"
// @Router /users [get]
func (c *ManagementController) GetUsers(ctx *gin.Context) {
fmt.Println("GetUsers")
user, ok := user_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
request := &user_dto.ListUsersRequestDTO{}
if err := ctx.ShouldBindQuery(request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
return
}
// Set defaults if not provided
if request.Limit <= 0 || request.Limit > 100 {
request.Limit = 20
}
if request.Offset < 0 {
request.Offset = 0
}
users, total, err := c.managementService.GetUsers(
user,
request.Limit,
request.Offset,
request.BeforeDate,
request.Query,
)
if err != nil {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
// Convert to response format
userProfiles := make([]user_dto.UserProfileResponseDTO, len(users))
for i, u := range users {
userProfiles[i] = user_dto.UserProfileResponseDTO{
ID: u.ID,
Email: u.Email,
Name: u.Name,
Role: u.Role,
IsActive: u.IsActiveUser(),
CreatedAt: u.CreatedAt,
}
}
response := user_dto.ListUsersResponseDTO{
Users: userProfiles,
Total: total,
}
ctx.JSON(http.StatusOK, response)
}
// GetUserProfile
// @Summary Get user profile
// @Description Get user profile information (users can view own profile, admins can view any)
// @Tags user-management
// @Produce json
// @Security BearerAuth
// @Param id path string true "User ID"
// @Success 200 {object} users_dto.UserProfileResponseDTO
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 403 {object} map[string]string "Forbidden"
// @Router /users/{id} [get]
func (c *ManagementController) GetUserProfile(ctx *gin.Context) {
currentUser, ok := user_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
userIDStr := ctx.Param("id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
user, err := c.managementService.GetUserProfile(userID, currentUser)
if err != nil {
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
return
}
profile := user_dto.UserProfileResponseDTO{
ID: user.ID,
Email: user.Email,
Name: user.Name,
Role: user.Role,
IsActive: user.IsActiveUser(),
CreatedAt: user.CreatedAt,
}
ctx.JSON(http.StatusOK, profile)
}
// DeactivateUser
// @Summary Deactivate user
// @Description Deactivate a user account (admin only)
// @Tags user-management
// @Security BearerAuth
// @Param id path string true "User ID"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 403 {object} map[string]string "Forbidden"
// @Router /users/{id}/deactivate [post]
func (c *ManagementController) DeactivateUser(ctx *gin.Context) {
currentUser, ok := user_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
userIDStr := ctx.Param("id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
if err := c.managementService.DeactivateUser(userID, currentUser); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User deactivated successfully"})
}
// ActivateUser
// @Summary Activate user
// @Description Activate a user account (admin only)
// @Tags user-management
// @Security BearerAuth
// @Param id path string true "User ID"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 403 {object} map[string]string "Forbidden"
// @Router /users/{id}/activate [post]
func (c *ManagementController) ActivateUser(ctx *gin.Context) {
currentUser, ok := user_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
userIDStr := ctx.Param("id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
if err := c.managementService.ActivateUser(userID, currentUser); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User activated successfully"})
}
// ChangeUserRole
// @Summary Change user role
// @Description Change a user's role (admin only)
// @Tags user-management
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param id path string true "User ID"
// @Param request body users_dto.ChangeUserRoleRequestDTO true "Role change data"
// @Success 200 {object} map[string]string
// @Failure 400 {object} map[string]string "Bad request"
// @Failure 401 {object} map[string]string "Unauthorized"
// @Failure 403 {object} map[string]string "Forbidden"
// @Router /users/{id}/role [put]
func (c *ManagementController) ChangeUserRole(ctx *gin.Context) {
currentUser, ok := user_middleware.GetUserFromContext(ctx)
if !ok {
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
return
}
userIDStr := ctx.Param("id")
userID, err := uuid.Parse(userIDStr)
if err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
var request user_dto.ChangeUserRoleRequestDTO
if err := ctx.ShouldBindJSON(&request); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request format"})
return
}
if err := c.managementService.ChangeUserRole(userID, request.Role, currentUser); err != nil {
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx.JSON(http.StatusOK, gin.H{"message": "User role changed successfully"})
}

Some files were not shown because too many files have changed in this diff Show More