mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 08:41:58 +02:00
Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6cfc0ca79b | ||
|
|
5d27123bd7 | ||
|
|
79ca374bb6 | ||
|
|
b3f1a6f7e5 | ||
|
|
d521e2abc6 | ||
|
|
82eca7501b | ||
|
|
51866437fd | ||
|
|
244a56d1bb | ||
|
|
95c833b619 | ||
|
|
878fad5747 | ||
|
|
6ff3096695 | ||
|
|
b4b514c2d5 | ||
|
|
da0fec6624 | ||
|
|
408675023a | ||
|
|
0bc93389cc | ||
|
|
c8e6aea6e1 | ||
|
|
981ad21471 | ||
|
|
177a9c782c | ||
|
|
069d6bc8fe | ||
|
|
242d5543d4 | ||
|
|
02c735bc5a | ||
|
|
793b575146 | ||
|
|
a6e84b45f2 | ||
|
|
a941fbd093 | ||
|
|
4492ba41f5 | ||
|
|
3a5ac4b479 | ||
|
|
77aaabeaa1 | ||
|
|
01911dbf72 | ||
|
|
1a16f27a5d | ||
|
|
778db71625 | ||
|
|
45fc9a7fff | ||
|
|
7f5e786261 | ||
|
|
9b066bcb8a | ||
|
|
9ea795b48f | ||
|
|
a809dc8a9c | ||
|
|
bd053b51a3 | ||
|
|
431e9861f4 | ||
|
|
de1fd4c4da | ||
|
|
df55fd17d5 | ||
|
|
fcc894d1f5 | ||
|
|
7307a515e2 | ||
|
|
5f280c0d6d | ||
|
|
492605a1b0 | ||
|
|
f9eaead8a1 | ||
|
|
aad9ed6589 | ||
|
|
181c32ded3 | ||
|
|
7fb59bb5d0 | ||
|
|
dc9ddae42e | ||
|
|
a409c8ccb3 | ||
|
|
a018b0c62f | ||
|
|
97d7253dda | ||
|
|
81aadd19e1 | ||
|
|
432bdced3e | ||
|
|
fcfe382a81 | ||
|
|
7055b85c34 | ||
|
|
0abc2225de | ||
|
|
31685f7bb0 | ||
|
|
9dbcf91442 | ||
|
|
6ef59e888b | ||
|
|
2009eabb14 | ||
|
|
fa073ab76c | ||
|
|
f24b3219bc | ||
|
|
332971a014 | ||
|
|
7bb057ed2d | ||
|
|
d814c1362b | ||
|
|
41fe554272 | ||
|
|
00c93340db | ||
|
|
21770b259b | ||
|
|
5f36f269f0 | ||
|
|
76d67d6be8 | ||
|
|
7adb921812 | ||
|
|
0107dab026 | ||
|
|
dee330ed59 | ||
|
|
299f152704 | ||
|
|
f3edf1a102 | ||
|
|
f425160765 | ||
|
|
13f2d3938f | ||
|
|
59692cd41b | ||
|
|
ac78fe306c |
19
.github/workflows/ci-release.yml
vendored
19
.github/workflows/ci-release.yml
vendored
@@ -2,9 +2,9 @@ name: CI and Release
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
branches: ["**"]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
branches: ["**"]
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -127,16 +127,23 @@ jobs:
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=${{ secrets.TEST_GOOGLE_DRIVE_CLIENT_SECRET }}
|
||||
TEST_GOOGLE_DRIVE_TOKEN_JSON=${{ secrets.TEST_GOOGLE_DRIVE_TOKEN_JSON }}
|
||||
# testing DBs
|
||||
TEST_POSTGRES_12_PORT=5000
|
||||
TEST_POSTGRES_13_PORT=5001
|
||||
TEST_POSTGRES_14_PORT=5002
|
||||
TEST_POSTGRES_15_PORT=5003
|
||||
TEST_POSTGRES_16_PORT=5004
|
||||
TEST_POSTGRES_17_PORT=5005
|
||||
TEST_POSTGRES_18_PORT=5006
|
||||
# testing S3
|
||||
TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing Azure Blob
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=5006
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
|
||||
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -150,6 +157,7 @@ jobs:
|
||||
timeout 60 bash -c 'until docker exec dev-db pg_isready -h localhost -p 5437 -U postgres; do sleep 2; done'
|
||||
|
||||
# Wait for test databases
|
||||
timeout 60 bash -c 'until nc -z localhost 5000; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5001; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5002; do sleep 2; done'
|
||||
timeout 60 bash -c 'until nc -z localhost 5003; do sleep 2; done'
|
||||
@@ -159,6 +167,9 @@ jobs:
|
||||
# Wait for MinIO
|
||||
timeout 60 bash -c 'until nc -z localhost 9000; do sleep 2; done'
|
||||
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
|
||||
- name: Create data and temp directories
|
||||
run: |
|
||||
# Create directories that are used for backups and restore
|
||||
@@ -181,7 +192,7 @@ jobs:
|
||||
- name: Run Go tests
|
||||
run: |
|
||||
cd backend
|
||||
go test ./internal/...
|
||||
go test -p=1 -count=1 -failfast ./internal/...
|
||||
|
||||
- name: Stop test containers
|
||||
if: always()
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -3,4 +3,6 @@ postgresus-data/
|
||||
.env
|
||||
pgdata/
|
||||
docker-compose.yml
|
||||
node_modules/
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
29
.pre-commit-config.yaml
Normal file
29
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,29 @@
|
||||
# Pre-commit configuration
|
||||
# See https://pre-commit.com for more information
|
||||
repos:
|
||||
# Frontend checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: frontend-format
|
||||
name: Frontend Format (Prettier)
|
||||
entry: powershell -Command "cd frontend; npm run format"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx|json|css|md)$
|
||||
pass_filenames: false
|
||||
|
||||
- id: frontend-lint
|
||||
name: Frontend Lint (ESLint)
|
||||
entry: powershell -Command "cd frontend; npm run lint"
|
||||
language: system
|
||||
files: ^frontend/.*\.(ts|tsx|js|jsx)$
|
||||
pass_filenames: false
|
||||
|
||||
# Backend checks
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: backend-format-and-lint
|
||||
name: Backend Format & Lint (golangci-lint)
|
||||
entry: powershell -Command "cd backend; golangci-lint fmt; golangci-lint run"
|
||||
language: system
|
||||
files: ^backend/.*\.go$
|
||||
pass_filenames: false
|
||||
75
Dockerfile
75
Dockerfile
@@ -13,18 +13,30 @@ COPY frontend/ ./
|
||||
|
||||
# Copy .env file (with fallback to .env.production.example)
|
||||
RUN if [ ! -f .env ]; then \
|
||||
if [ -f .env.production.example ]; then \
|
||||
cp .env.production.example .env; \
|
||||
fi; \
|
||||
fi
|
||||
if [ -f .env.production.example ]; then \
|
||||
cp .env.production.example .env; \
|
||||
fi; \
|
||||
fi
|
||||
|
||||
RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.23.3 AS backend-build
|
||||
|
||||
# Install Go public tools needed in runtime
|
||||
RUN curl -fsSL https://raw.githubusercontent.com/pressly/goose/master/install.sh | sh
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
|
||||
# Install Go public tools needed in runtime. Use `go build` for goose so the
|
||||
# binary is compiled for the target architecture instead of downloading a
|
||||
# prebuilt binary which may have the wrong architecture (causes exec format
|
||||
# errors on ARM).
|
||||
RUN git clone --depth 1 --branch v3.24.3 https://github.com/pressly/goose.git /tmp/goose && \
|
||||
cd /tmp/goose/cmd/goose && \
|
||||
GOOS=${TARGETOS:-linux} GOARCH=${TARGETARCH:-amd64} \
|
||||
go build -o /usr/local/bin/goose . && \
|
||||
rm -rf /tmp/goose
|
||||
RUN go install github.com/swaggo/swag/cmd/swag@v1.16.4
|
||||
|
||||
# Set working directory
|
||||
@@ -49,35 +61,38 @@ ARG TARGETOS
|
||||
ARG TARGETARCH
|
||||
ARG TARGETVARIANT
|
||||
RUN CGO_ENABLED=0 \
|
||||
GOOS=$TARGETOS \
|
||||
GOARCH=$TARGETARCH \
|
||||
go build -o /app/main ./cmd/main.go
|
||||
GOOS=$TARGETOS \
|
||||
GOARCH=$TARGETARCH \
|
||||
go build -o /app/main ./cmd/main.go
|
||||
|
||||
|
||||
# ========= RUNTIME =========
|
||||
FROM --platform=$TARGETPLATFORM debian:bookworm-slim
|
||||
FROM debian:bookworm-slim
|
||||
|
||||
# Add version metadata to runtime image
|
||||
ARG APP_VERSION=dev
|
||||
LABEL org.opencontainers.image.version=$APP_VERSION
|
||||
ENV APP_VERSION=$APP_VERSION
|
||||
|
||||
# Install PostgreSQL server and client tools (versions 13-17)
|
||||
# Set production mode for Docker containers
|
||||
ENV ENV_MODE=production
|
||||
|
||||
# Install PostgreSQL server and client tools (versions 12-18)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
wget ca-certificates gnupg lsb-release sudo gosu && \
|
||||
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
|
||||
echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \
|
||||
> /etc/apt/sources.list.d/pgdg.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
|
||||
postgresql-client-16 postgresql-client-17 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
wget ca-certificates gnupg lsb-release sudo gosu && \
|
||||
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
|
||||
echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \
|
||||
> /etc/apt/sources.list.d/pgdg.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
|
||||
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create postgres user and set up directories
|
||||
RUN useradd -m -s /bin/bash postgres || true && \
|
||||
mkdir -p /postgresus-data/pgdata && \
|
||||
chown -R postgres:postgres /postgresus-data/pgdata
|
||||
mkdir -p /postgresus-data/pgdata && \
|
||||
chown -R postgres:postgres /postgresus-data/pgdata
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -96,10 +111,10 @@ COPY --from=backend-build /app/ui/build ./ui/build
|
||||
# Copy .env file (with fallback to .env.production.example)
|
||||
COPY backend/.env* /app/
|
||||
RUN if [ ! -f /app/.env ]; then \
|
||||
if [ -f /app/.env.production.example ]; then \
|
||||
cp /app/.env.production.example /app/.env; \
|
||||
fi; \
|
||||
fi
|
||||
if [ -f /app/.env.production.example ]; then \
|
||||
cp /app/.env.production.example /app/.env; \
|
||||
fi; \
|
||||
fi
|
||||
|
||||
# Create startup script
|
||||
COPY <<EOF /app/start.sh
|
||||
@@ -151,8 +166,10 @@ done
|
||||
echo "Setting up database and user..."
|
||||
gosu postgres \$PG_BIN/psql -p 5437 -h localhost -d postgres << 'SQL'
|
||||
ALTER USER postgres WITH PASSWORD 'Q1234567';
|
||||
CREATE DATABASE "postgresus" OWNER postgres;
|
||||
\q
|
||||
SELECT 'CREATE DATABASE postgresus OWNER postgres'
|
||||
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = 'postgresus')
|
||||
\\gexec
|
||||
\\q
|
||||
SQL
|
||||
|
||||
# Start the main application
|
||||
@@ -168,4 +185,4 @@ EXPOSE 4005
|
||||
VOLUME ["/postgresus-data"]
|
||||
|
||||
ENTRYPOINT ["/app/start.sh"]
|
||||
CMD []
|
||||
CMD []
|
||||
215
LICENSE
215
LICENSE
@@ -1,21 +1,202 @@
|
||||
MIT License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
Copyright (c) 2025 Postgresus
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
1. Definitions.
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"Licensor" shall mean the copyright owner or entity granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(which shall not include communications that are solely written
|
||||
by You).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based upon (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and derivative works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control
|
||||
systems, and issue tracking systems that are managed by, or on behalf
|
||||
of, the Licensor for the purpose of discussing and improving the Work,
|
||||
but excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to use, reproduce, modify, distribute, prepare
|
||||
Derivative Works of, and publicly display, publicly perform,
|
||||
sublicense, and distribute the Work and such Derivative Works in
|
||||
Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright notice to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. When redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "license" line as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2025 Postgresus
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
52
README.md
52
README.md
@@ -1,15 +1,15 @@
|
||||
<div align="center">
|
||||
<img src="assets/logo.svg" style="margin-bottom: 20px;" alt="Postgresus Logo" width="250"/>
|
||||
|
||||
<h3>PostgreSQL monitoring and backup</h3>
|
||||
<p>Free, open source and self-hosted solution for automated PostgreSQL monitoring and backups. With multiple storage options and notifications</p>
|
||||
<h3>PostgreSQL backup</h3>
|
||||
<p>Free, open source and self-hosted solution for automated PostgreSQL backups. With multiple storage options and notifications</p>
|
||||
|
||||
<!-- Badges -->
|
||||
[](LICENSE)
|
||||
[](LICENSE)
|
||||
[](https://hub.docker.com/r/rostislavdugin/postgresus)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://www.postgresql.org/)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
[](https://github.com/RostislavDugin/postgresus)
|
||||
|
||||
@@ -40,13 +40,13 @@
|
||||
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
|
||||
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
|
||||
|
||||
### 🗄️ **Multiple Storage Destinations**
|
||||
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
|
||||
|
||||
- **Local storage**: Keep backups on your VPS/server
|
||||
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
|
||||
- **Secure**: All data stays under your control
|
||||
|
||||
### 📱 **Smart Notifications**
|
||||
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
|
||||
|
||||
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
|
||||
- **Real-time updates**: Success and failure notifications
|
||||
@@ -54,23 +54,31 @@
|
||||
|
||||
### 🐘 **PostgreSQL Support**
|
||||
|
||||
- **Multiple versions**: PostgreSQL 13, 14, 15, 16 and 17
|
||||
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
|
||||
- **SSL support**: Secure connections available
|
||||
- **Easy restoration**: One-click restore from any backup
|
||||
|
||||
### 🔒 **Enterprise-grade security** <a href="https://postgresus.com/security">(docs)</a>
|
||||
|
||||
- **AES-256-GCM encryption**: Enterprise-grade protection for backup files
|
||||
- **Zero-trust storage**: Backups are encrypted and they are useless to attackers, so you can keep them in shared storages like S3, Azure Blob Storage, etc.
|
||||
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
|
||||
- **Read-only user**: Postgresus uses by default a read-only user for backups and never stores anything that can change your data
|
||||
|
||||
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
|
||||
|
||||
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
|
||||
- **Access management**: Control who can view or manage specific databases with role-based permissions
|
||||
- **Audit logs**: Track all system activities and changes made by users
|
||||
- **User roles**: Assign viewer, member, admin or owner roles within workspaces
|
||||
|
||||
### 🐳 **Self-Hosted & Secure**
|
||||
|
||||
- **Docker-based**: Easy deployment and management
|
||||
- **Privacy-first**: All your data stays on your infrastructure
|
||||
- **Open source**: MIT licensed, inspect every line of code
|
||||
- **Open source**: Apache 2.0 licensed, inspect every line of code
|
||||
|
||||
### 📊 **Monitoring & Insights**
|
||||
|
||||
- **Real-time metrics**: Track database health
|
||||
- **Historical data**: View trends and patterns over time
|
||||
- **Alert system**: Get notified when issues are detected
|
||||
|
||||
### 📦 Installation
|
||||
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
|
||||
|
||||
You have three ways to install Postgresus:
|
||||
|
||||
@@ -124,8 +132,6 @@ This single command will:
|
||||
Create a `docker-compose.yml` file with the following configuration:
|
||||
|
||||
```yaml
|
||||
version: "3"
|
||||
|
||||
services:
|
||||
postgresus:
|
||||
container_name: postgresus
|
||||
@@ -155,22 +161,24 @@ docker compose up -d
|
||||
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
|
||||
|
||||
### 🔑 Resetting Admin Password
|
||||
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
|
||||
|
||||
If you need to reset the admin password, you can use the built-in password reset command:
|
||||
If you need to reset the password, you can use the built-in password reset command:
|
||||
|
||||
```bash
|
||||
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123"
|
||||
docker exec -it postgresus ./main --new-password="YourNewSecurePassword123" --email="admin"
|
||||
```
|
||||
|
||||
Replace `admin` with the actual email address of the user whose password you want to reset.
|
||||
|
||||
---
|
||||
|
||||
## 📝 License
|
||||
|
||||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Read [contributing guide](contribute/readme.md) for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
Contributions are welcome! Read <a href="https://postgresus.com/contributing">contributing guide</a> for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
1607
assets/dashboard.svg
1607
assets/dashboard.svg
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 791 KiB After Width: | Height: | Size: 913 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 12 KiB After Width: | Height: | Size: 34 KiB |
@@ -1,14 +1,152 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
Always place private methods to the bottom of file
|
||||
|
||||
Code should look like:
|
||||
**This rule applies to ALL Go files including tests, services, controllers, repositories, etc.**
|
||||
|
||||
type SomeService struct {
|
||||
func PublicMethod(...) ...
|
||||
In Go, exported (public) functions/methods start with uppercase letters, while unexported (private) ones start with lowercase letters.
|
||||
|
||||
func privateMethod(...) ...
|
||||
}
|
||||
## Structure Order:
|
||||
|
||||
1. Type definitions and constants
|
||||
2. Public methods/functions (uppercase)
|
||||
3. Private methods/functions (lowercase)
|
||||
|
||||
## Examples:
|
||||
|
||||
### Service with methods:
|
||||
|
||||
```go
|
||||
type UserService struct {
|
||||
repository *UserRepository
|
||||
}
|
||||
|
||||
// Public methods first
|
||||
func (s *UserService) CreateUser(user *User) error {
|
||||
if err := s.validateUser(user); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.repository.Save(user)
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(id uuid.UUID) (*User, error) {
|
||||
return s.repository.FindByID(id)
|
||||
}
|
||||
|
||||
// Private methods at the bottom
|
||||
func (s *UserService) validateUser(user *User) error {
|
||||
if user.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
```
|
||||
|
||||
### Package-level functions:
|
||||
|
||||
```go
|
||||
package utils
|
||||
|
||||
// Public functions first
|
||||
func ProcessData(data []byte) (Result, error) {
|
||||
cleaned := sanitizeInput(data)
|
||||
return parseData(cleaned)
|
||||
}
|
||||
|
||||
func ValidateInput(input string) bool {
|
||||
return isValidFormat(input) && checkLength(input)
|
||||
}
|
||||
|
||||
// Private functions at the bottom
|
||||
func sanitizeInput(data []byte) []byte {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func parseData(data []byte) (Result, error) {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func isValidFormat(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
|
||||
func checkLength(input string) bool {
|
||||
// implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Test files:
|
||||
|
||||
```go
|
||||
package user_test
|
||||
|
||||
// Public test functions first
|
||||
func Test_CreateUser_ValidInput_UserCreated(t *testing.T) {
|
||||
user := createTestUser()
|
||||
result, err := service.CreateUser(user)
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
|
||||
func Test_GetUser_ExistingUser_ReturnsUser(t *testing.T) {
|
||||
user := createTestUser()
|
||||
// test implementation
|
||||
}
|
||||
|
||||
// Private helper functions at the bottom
|
||||
func createTestUser() *User {
|
||||
return &User{
|
||||
Name: "Test User",
|
||||
Email: "test@example.com",
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestDatabase() *Database {
|
||||
// setup implementation
|
||||
}
|
||||
```
|
||||
|
||||
### Controller example:
|
||||
|
||||
```go
|
||||
type ProjectController struct {
|
||||
service *ProjectService
|
||||
}
|
||||
|
||||
// Public HTTP handlers first
|
||||
func (c *ProjectController) CreateProject(ctx *gin.Context) {
|
||||
var request CreateProjectRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
c.handleError(ctx, err)
|
||||
return
|
||||
}
|
||||
// handler logic
|
||||
}
|
||||
|
||||
func (c *ProjectController) GetProject(ctx *gin.Context) {
|
||||
projectID := c.extractProjectID(ctx)
|
||||
// handler logic
|
||||
}
|
||||
|
||||
// Private helper methods at the bottom
|
||||
func (c *ProjectController) handleError(ctx *gin.Context, err error) {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
}
|
||||
|
||||
func (c *ProjectController) extractProjectID(ctx *gin.Context) uuid.UUID {
|
||||
return uuid.MustParse(ctx.Param("projectId"))
|
||||
}
|
||||
```
|
||||
|
||||
## Key Points:
|
||||
|
||||
- **Exported/Public** = starts with uppercase letter (CreateUser, GetProject)
|
||||
- **Unexported/Private** = starts with lowercase letter (validateUser, handleError)
|
||||
- This improves code readability by showing the public API first
|
||||
- Private helpers are implementation details, so they go at the bottom
|
||||
- Apply this rule consistently across ALL Go files in the project
|
||||
|
||||
@@ -1,7 +1,45 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Do not write obvious comments.
|
||||
Write meaningful code, give meaningful names
|
||||
|
||||
## Comment Guidelines
|
||||
|
||||
1. **No obvious comments** - Don't state what the code already clearly shows
|
||||
2. **Functions and variables should have meaningful names** - Code should be self-documenting
|
||||
3. **Comments for unclear code only** - Only add comments when code logic isn't immediately clear
|
||||
|
||||
## Key Principles:
|
||||
|
||||
- **Code should tell a story** - Use descriptive variable and function names
|
||||
- **Comments explain WHY, not WHAT** - The code shows what happens, comments explain business logic or complex decisions
|
||||
- **Prefer refactoring over commenting** - If code needs explaining, consider making it clearer instead
|
||||
- **API documentation is required** - Swagger comments for all HTTP endpoints are mandatory
|
||||
- **Complex algorithms deserve comments** - Mathematical formulas, business rules, or non-obvious optimizations
|
||||
|
||||
Example of useless comment:
|
||||
|
||||
1.
|
||||
|
||||
```sql
|
||||
// Create projects table
|
||||
CREATE TABLE projects (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
```
|
||||
|
||||
2.
|
||||
|
||||
```go
|
||||
// Create test project
|
||||
project := CreateTestProject(projectName, user, router)
|
||||
```
|
||||
|
||||
3.
|
||||
|
||||
```go
|
||||
// CreateValidLogItems creates valid log items for testing
|
||||
func CreateValidLogItems(count int, uniqueID string) []logs_receiving.LogItemRequestDTO {
|
||||
```
|
||||
|
||||
@@ -1,55 +1,133 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
1. When we write controller:
|
||||
|
||||
- we combine all routes to single controller
|
||||
- names them as .WhatWeDo (not "handlers") concept
|
||||
|
||||
2. We use gin and *gin.Context for all routes.
|
||||
Example:
|
||||
2. We use gin and \*gin.Context for all routes.
|
||||
Example:
|
||||
|
||||
func (c *TasksController) GetAvailableTasks(ctx *gin.Context) ...
|
||||
|
||||
3. We document all routes with Swagger in the following format:
|
||||
|
||||
// SignIn
|
||||
// @Summary Authenticate a user
|
||||
// @Description Authenticate a user with email and password
|
||||
// @Tags users
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body SignInRequest true "User signin data"
|
||||
// @Success 200 {object} SignInResponse
|
||||
// @Failure 400
|
||||
// @Router /users/signin [post]
|
||||
package audit_logs
|
||||
|
||||
Do not forget to write docs.
|
||||
You can avoid description if it is useless.
|
||||
Specify particular acceping \ producing models
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
4. All controllers should have RegisterRoutes method which receives
|
||||
RouterGroup (always put this routes on the top of file under controller definition)
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
Example:
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
|
||||
func (c *OrderController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/bots/users/orders/generate", c.GenerateOrder)
|
||||
router.POST("/bots/users/orders/generate-by-admin", c.GenerateOrderByAdmin)
|
||||
router.GET("/bots/users/orders/mark-as-paid-by-admin", c.MarkOrderAsPaidByAdmin)
|
||||
router.GET("/bots/users/orders/payments-by-bot", c.GetOrderPaymentsByBot)
|
||||
router.GET("/bots/users/orders/payments-by-user", c.GetOrderPaymentsByUser)
|
||||
router.GET("/bots/users/orders/orders-by-user-for-admin", c.GetOrdersByUserForAdmin)
|
||||
router.POST("/bots/users/orders/orders-by-user-for-user", c.GetOrdersByUserForUser)
|
||||
router.POST("/bots/users/orders/order", c.GetOrder)
|
||||
router.POST("/bots/users/orders/cancel-subscription-by-user", c.CancelSubscriptionByUser)
|
||||
router.GET("/bots/users/orders/cancel-subscription-by-admin", c.CancelSubscriptionByAdmin)
|
||||
router.GET(
|
||||
"/bots/users/orders/cancel-subscriptions-by-payment-option",
|
||||
c.CancelSubscriptionsByPaymentOption,
|
||||
)
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService \*AuditLogService
|
||||
}
|
||||
|
||||
5. Check that use use valid .Query("param") and .Param("param") methods.
|
||||
If route does not have param - use .Query("query")
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(\*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
|
||||
}
|
||||
|
||||
671
backend/.cursor/rules/crud.mdc
Normal file
671
backend/.cursor/rules/crud.mdc
Normal file
@@ -0,0 +1,671 @@
|
||||
---
|
||||
alwaysApply: false
|
||||
---
|
||||
|
||||
This is example of CRUD:
|
||||
|
||||
------ backend/internal/features/audit_logs/controller.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/controller_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
users_middleware "postgresus/internal/features/users/middleware"
|
||||
users_services "postgresus/internal/features/users/services"
|
||||
users_testing "postgresus/internal/features/users/testing"
|
||||
"postgresus/internal/storage"
|
||||
test_utils "postgresus/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_AdminSucceedsAndMemberGetsForbidden(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs
|
||||
createAuditLog(service, "Test log with user", &adminUser.UserID, nil)
|
||||
createAuditLog(service, "Test log with project", nil, &projectID)
|
||||
createAuditLog(service, "Test log standalone", nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=10", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response.AuditLogs), 3)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(3))
|
||||
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log with user")
|
||||
assert.Contains(t, messages, "Test log with project")
|
||||
assert.Contains(t, messages, "Test log standalone")
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_PermissionsEnforcedCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
projectID := uuid.New()
|
||||
|
||||
// Create test logs for different users
|
||||
createAuditLog(service, "Test log user1 first", &user1.UserID, nil)
|
||||
createAuditLog(service, "Test log user1 second", &user1.UserID, &projectID)
|
||||
createAuditLog(service, "Test log user2 first", &user2.UserID, nil)
|
||||
createAuditLog(service, "Test log user2 second", &user2.UserID, &projectID)
|
||||
createAuditLog(service, "Test project log", nil, &projectID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=10", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
assert.Equal(t, 2, len(user1Response.AuditLogs))
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test log user1 first")
|
||||
assert.Contains(t, messages, "Test log user1 second")
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
assert.Equal(t, 2, len(ownLogsResponse.AuditLogs))
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_FilterAuditLogsByTime_ReturnsOnlyLogsBeforeDate(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
db := storage.GetDb()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Create logs with different timestamps
|
||||
createTimedLog(db, &adminUser.UserID, "Test old log", baseTime.Add(-2*time.Hour))
|
||||
createTimedLog(db, &adminUser.UserID, "Test recent log", baseTime.Add(-30*time.Minute))
|
||||
createAuditLog(service, "Test current log", &adminUser.UserID, nil)
|
||||
|
||||
// Test filtering - get logs before 1 hour ago
|
||||
beforeTime := baseTime.Add(-1 * time.Hour)
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/global?beforeDate=%s", beforeTime.Format(time.RFC3339)),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &filteredResponse)
|
||||
|
||||
// Verify only old log is returned
|
||||
messages := extractMessages(filteredResponse.AuditLogs)
|
||||
assert.Contains(t, messages, "Test old log")
|
||||
assert.NotContains(t, messages, "Test recent log")
|
||||
assert.NotContains(t, messages, "Test current log")
|
||||
|
||||
// Test without filter - should get all logs
|
||||
var allResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &allResponse)
|
||||
assert.GreaterOrEqual(t, len(allResponse.AuditLogs), 3)
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/di.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "postgresus/internal/features/users/services"
|
||||
"postgresus/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/dto.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import "time"
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLog `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/models.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
ProjectID *uuid.UUID `json:"projectId" gorm:"column:project_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/repository.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"postgresus/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(limit, offset int, beforeDate *time.Time) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByProject(
|
||||
projectID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLog, error) {
|
||||
var auditLogs []*AuditLog
|
||||
|
||||
query := storage.GetDb().
|
||||
Where("project_id = ?", projectID).
|
||||
Order("created_at DESC")
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
user_models "postgresus/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
projectID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(targetUserID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetProjectAuditLogs(
|
||||
projectID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByProject(projectID, limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
------ backend/internal/features/audit_logs/service_test.go ------
|
||||
|
||||
```
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus/internal/features/users/enums"
|
||||
users_testing "postgresus/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_ProjectSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
project1ID, project2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for projects
|
||||
createAuditLog(service, "Test project1 log first", &user1.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project1 log second", &user2.UserID, &project1ID)
|
||||
createAuditLog(service, "Test project2 log first", &user1.UserID, &project2ID)
|
||||
createAuditLog(service, "Test project2 log second", &user2.UserID, &project2ID)
|
||||
createAuditLog(service, "Test no project log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test project 1 logs
|
||||
project1Response, err := service.GetProjectAuditLogs(project1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(project1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test project1 log first")
|
||||
assert.Contains(t, messages, "Test project1 log second")
|
||||
for _, log := range project1Response.AuditLogs {
|
||||
assert.Equal(t, &project1ID, log.ProjectID)
|
||||
}
|
||||
|
||||
// Test project 2 logs
|
||||
project2Response, err := service.GetProjectAuditLogs(project2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(project2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(project2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test project2 log first")
|
||||
assert.Contains(t, messages2, "Test project2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetProjectAuditLogs(project1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, projectID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, projectID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLog) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func createTimedLog(db *gorm.DB, userID *uuid.UUID, message string, createdAt time.Time) {
|
||||
log := &AuditLog{
|
||||
ID: uuid.New(),
|
||||
UserID: userID,
|
||||
Message: message,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
db.Create(log)
|
||||
}
|
||||
|
||||
```
|
||||
12
backend/.cursor/rules/refactor.mdc
Normal file
12
backend/.cursor/rules/refactor.mdc
Normal file
@@ -0,0 +1,12 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
|
||||
When applying changes, do not forget to refactor old code.
|
||||
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder).
|
||||
@@ -1,12 +1,147 @@
|
||||
---
|
||||
description:
|
||||
globs:
|
||||
description:
|
||||
globs:
|
||||
alwaysApply: true
|
||||
---
|
||||
Write tests names in the format:
|
||||
|
||||
Test_WhatWeDo_WhatWeExpect
|
||||
After writing tests, always launch them and verify that they pass.
|
||||
|
||||
Example:
|
||||
- Test_TestConnection_ConnectionSucceeds
|
||||
- Test_SaveNewStorage_StorageReturnedViaGet
|
||||
## Test Naming Format
|
||||
|
||||
Use these naming patterns:
|
||||
|
||||
- `Test_WhatWeDo_WhatWeExpect`
|
||||
- `Test_WhatWeDo_WhichConditions_WhatWeExpect`
|
||||
|
||||
## Examples from Real Codebase:
|
||||
|
||||
- `Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated`
|
||||
- `Test_UpdateProject_WhenUserIsProjectAdmin_ProjectUpdated`
|
||||
- `Test_DeleteApiKey_WhenUserIsProjectMember_ReturnsForbidden`
|
||||
- `Test_GetProjectAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly`
|
||||
- `Test_ProjectLifecycleE2E_CompletesSuccessfully`
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
**Prefer Controllers Over Unit Tests:**
|
||||
|
||||
- Test through HTTP endpoints via controllers whenever possible
|
||||
- Avoid testing repositories, services in isolation - test via API instead
|
||||
- Only use unit tests for complex model logic when no API exists
|
||||
- Name test files `controller_test.go` or `service_test.go`, not `integration_test.go`
|
||||
|
||||
**Extract Common Logic to Testing Utilities:**
|
||||
|
||||
- Create `testing.go` or `testing/testing.go` files for shared test utilities
|
||||
- Extract router creation, user setup, models creation helpers (in API, not just structs creation)
|
||||
- Reuse common patterns across different test files
|
||||
|
||||
**Refactor Existing Tests:**
|
||||
|
||||
- When working with existing tests, always look for opportunities to refactor and improve
|
||||
- Extract repetitive setup code to common utilities
|
||||
- Simplify complex tests by breaking them into smaller, focused tests
|
||||
- Replace inline test data creation with reusable helper functions
|
||||
- Consolidate similar test patterns across different test files
|
||||
- Make tests more readable and maintainable for other developers
|
||||
|
||||
## Testing Utilities Structure
|
||||
|
||||
**Create `testing.go` or `testing/testing.go` files with common utilities:**
|
||||
|
||||
```go
|
||||
package projects_testing
|
||||
|
||||
// CreateTestRouter creates unified router for all controllers
|
||||
func CreateTestRouter(controllers ...ControllerInterface) *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
|
||||
for _, controller := range controllers {
|
||||
if routerGroup, ok := protected.(*gin.RouterGroup); ok {
|
||||
controller.RegisterRoutes(routerGroup)
|
||||
}
|
||||
}
|
||||
return router
|
||||
}
|
||||
|
||||
// CreateTestProjectViaAPI creates project through HTTP API
|
||||
func CreateTestProjectViaAPI(name string, owner *users_dto.SignInResponseDTO, router *gin.Engine) (*projects_models.Project, string) {
|
||||
request := projects_dto.CreateProjectRequestDTO{Name: name}
|
||||
w := MakeAPIRequest(router, "POST", "/api/v1/projects", "Bearer "+owner.Token, request)
|
||||
// Handle response...
|
||||
return project, owner.Token
|
||||
}
|
||||
|
||||
// AddMemberToProject adds member via API call
|
||||
func AddMemberToProject(project *projects_models.Project, member *users_dto.SignInResponseDTO, role users_enums.ProjectRole, ownerToken string, router *gin.Engine) {
|
||||
// Implementation...
|
||||
}
|
||||
```
|
||||
|
||||
## Controller Test Examples
|
||||
|
||||
**Permission-based testing:**
|
||||
|
||||
```go
|
||||
func Test_CreateApiKey_WhenUserIsProjectOwner_ApiKeyCreated(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project, _ := projects_testing.CreateTestProjectViaAPI("Test Project", owner, router)
|
||||
|
||||
request := CreateApiKeyRequestDTO{Name: "Test API Key"}
|
||||
var response ApiKey
|
||||
test_utils.MakePostRequestAndUnmarshal(t, router, "/api/v1/projects/api-keys/"+project.ID.String(), "Bearer "+owner.Token, request, http.StatusOK, &response)
|
||||
|
||||
assert.Equal(t, "Test API Key", response.Name)
|
||||
assert.NotEmpty(t, response.Token)
|
||||
}
|
||||
```
|
||||
|
||||
**Cross-project security testing:**
|
||||
|
||||
```go
|
||||
func Test_UpdateApiKey_WithApiKeyFromDifferentProject_ReturnsBadRequest(t *testing.T) {
|
||||
router := CreateApiKeyTestRouter(GetProjectController(), GetMembershipController())
|
||||
owner1 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
owner2 := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project1, _ := projects_testing.CreateTestProjectViaAPI("Project 1", owner1, router)
|
||||
project2, _ := projects_testing.CreateTestProjectViaAPI("Project 2", owner2, router)
|
||||
|
||||
apiKey := CreateTestApiKey("Cross Project Key", project1.ID, owner1.Token, router)
|
||||
|
||||
// Try to update via different project endpoint
|
||||
request := UpdateApiKeyRequestDTO{Name: &"Hacked Key"}
|
||||
resp := test_utils.MakePutRequest(t, router, "/api/v1/projects/api-keys/"+project2.ID.String()+"/"+apiKey.ID.String(), "Bearer "+owner2.Token, request, http.StatusBadRequest)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "API key does not belong to this project")
|
||||
}
|
||||
```
|
||||
|
||||
**E2E lifecycle testing:**
|
||||
|
||||
```go
|
||||
func Test_ProjectLifecycleE2E_CompletesSuccessfully(t *testing.T) {
|
||||
router := projects_testing.CreateTestRouter(GetProjectController(), GetMembershipController())
|
||||
|
||||
// 1. Create project
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
project := projects_testing.CreateTestProject("E2E Project", owner, router)
|
||||
|
||||
// 2. Add member
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
projects_testing.AddMemberToProject(project, member, users_enums.ProjectRoleMember, owner.Token, router)
|
||||
|
||||
// 3. Promote to admin
|
||||
projects_testing.ChangeMemberRole(project, member.UserID, users_enums.ProjectRoleAdmin, owner.Token, router)
|
||||
|
||||
// 4. Transfer ownership
|
||||
projects_testing.TransferProjectOwnership(project, member.UserID, owner.Token, router)
|
||||
|
||||
// 5. Verify new owner can manage project
|
||||
finalProject := projects_testing.GetProject(project.ID, member.Token, router)
|
||||
assert.Equal(t, project.ID, finalProject.ID)
|
||||
}
|
||||
```
|
||||
|
||||
@@ -17,13 +17,20 @@ TEST_GOOGLE_DRIVE_CLIENT_ID=
|
||||
TEST_GOOGLE_DRIVE_CLIENT_SECRET=
|
||||
TEST_GOOGLE_DRIVE_TOKEN_JSON="{\"access_token\":\"ya29..."
|
||||
# testing DBs
|
||||
TEST_POSTGRES_12_PORT=5000
|
||||
TEST_POSTGRES_13_PORT=5001
|
||||
TEST_POSTGRES_14_PORT=5002
|
||||
TEST_POSTGRES_15_PORT=5003
|
||||
TEST_POSTGRES_16_PORT=5004
|
||||
TEST_POSTGRES_17_PORT=5005
|
||||
TEST_POSTGRES_18_PORT=5006
|
||||
# testing S3
|
||||
TEST_MINIO_PORT=9000
|
||||
TEST_MINIO_CONSOLE_PORT=9001
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=5006
|
||||
TEST_NAS_PORT=7006
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=
|
||||
TEST_TELEGRAM_CHAT_ID=
|
||||
# testing Azure Blob Storage
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
3
backend/.gitignore
vendored
3
backend/.gitignore
vendored
@@ -12,4 +12,5 @@ swagger/swagger.yaml
|
||||
postgresus-backend.exe
|
||||
ui/build/*
|
||||
pgdata-for-restore/
|
||||
temp/
|
||||
temp/
|
||||
cmd.exe
|
||||
20
backend/Makefile
Normal file
20
backend/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
run:
|
||||
go run cmd/main.go
|
||||
|
||||
test:
|
||||
go test -p=1 -count=1 -failfast .\internal\...
|
||||
|
||||
lint:
|
||||
golangci-lint fmt && golangci-lint run
|
||||
|
||||
migration-create:
|
||||
goose create $(name) sql
|
||||
|
||||
migration-up:
|
||||
goose up
|
||||
|
||||
migration-down:
|
||||
goose down
|
||||
|
||||
swagger:
|
||||
swag init -g ./cmd/main.go -o swagger
|
||||
@@ -9,44 +9,39 @@ instead of postgresus-db from docker-compose.yml in the root folder.
|
||||
|
||||
# Run
|
||||
|
||||
To build:
|
||||
|
||||
> go build /cmd/main.go
|
||||
|
||||
To run:
|
||||
|
||||
> go run /cmd/main.go
|
||||
> make run
|
||||
|
||||
To run tests:
|
||||
|
||||
> go test ./internal/...
|
||||
> make test
|
||||
|
||||
Before commit (make sure `golangci-lint` is installed):
|
||||
|
||||
> golangci-lint fmt
|
||||
> golangci-lint run
|
||||
> make lint
|
||||
|
||||
# Migrations
|
||||
|
||||
To create migration:
|
||||
|
||||
> goose create MIGRATION_NAME sql
|
||||
> make migration-create name=MIGRATION_NAME
|
||||
|
||||
To run migrations:
|
||||
|
||||
> goose up
|
||||
> make migration-up
|
||||
|
||||
If latest migration failed:
|
||||
|
||||
To rollback on migration:
|
||||
|
||||
> goose down
|
||||
> make migration-down
|
||||
|
||||
# Swagger
|
||||
|
||||
To generate swagger docs:
|
||||
|
||||
> swag init -g .\cmd\main.go -o swagger
|
||||
> make swagger
|
||||
|
||||
Swagger URL is:
|
||||
|
||||
|
||||
@@ -13,18 +13,22 @@ import (
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/downdetect"
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/disk"
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
healthcheck_attempt "postgresus-backend/internal/features/healthcheck/attempt"
|
||||
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/restores"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
system_healthcheck "postgresus-backend/internal/features/system/healthcheck"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_controllers "postgresus-backend/internal/features/users/controllers"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
env_utils "postgresus-backend/internal/util/env"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
@@ -61,13 +65,20 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Handle password reset if flag is provided
|
||||
newPassword := flag.String("new-password", "", "Set a new password for the user")
|
||||
flag.Parse()
|
||||
if *newPassword != "" {
|
||||
resetPassword(*newPassword, log)
|
||||
err = secrets.GetSecretKeyService().MigrateKeyFromDbToFileIfExist()
|
||||
if err != nil {
|
||||
log.Error("Failed to migrate secret key from database to file", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = users_services.GetUserService().CreateInitialAdmin()
|
||||
if err != nil {
|
||||
log.Error("Failed to create initial admin", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
handlePasswordReset(log)
|
||||
|
||||
go generateSwaggerDocs(log)
|
||||
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -91,11 +102,33 @@ func main() {
|
||||
startServerWithGracefulShutdown(log, ginApp)
|
||||
}
|
||||
|
||||
func resetPassword(newPassword string, log *slog.Logger) {
|
||||
func handlePasswordReset(log *slog.Logger) {
|
||||
audit_logs.SetupDependencies()
|
||||
|
||||
newPassword := flag.String("new-password", "", "Set a new password for the user")
|
||||
email := flag.String("email", "", "Email of the user to reset password")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if *newPassword == "" {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Found reset password command - reseting password...")
|
||||
|
||||
if *email == "" {
|
||||
log.Info("No email provided, please provide an email via --email=\"some@email.com\" flag")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
resetPassword(*email, *newPassword, log)
|
||||
}
|
||||
|
||||
func resetPassword(email string, newPassword string, log *slog.Logger) {
|
||||
log.Info("Resetting password...")
|
||||
|
||||
userService := users.GetUserService()
|
||||
err := userService.ChangePassword(newPassword)
|
||||
userService := users_services.GetUserService()
|
||||
err := userService.ChangeUserPasswordByEmail(email, newPassword)
|
||||
if err != nil {
|
||||
log.Error("Failed to reset password", "error", err)
|
||||
os.Exit(1)
|
||||
@@ -146,38 +179,44 @@ func setUpRoutes(r *gin.Engine) {
|
||||
// Mount Swagger UI
|
||||
v1.GET("/docs/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
downdetectContoller := downdetect.GetDowndetectController()
|
||||
userController := users.GetUserController()
|
||||
notifierController := notifiers.GetNotifierController()
|
||||
storageController := storages.GetStorageController()
|
||||
databaseController := databases.GetDatabaseController()
|
||||
backupController := backups.GetBackupController()
|
||||
restoreController := restores.GetRestoreController()
|
||||
healthcheckController := system_healthcheck.GetHealthcheckController()
|
||||
healthcheckConfigController := healthcheck_config.GetHealthcheckConfigController()
|
||||
healthcheckAttemptController := healthcheck_attempt.GetHealthcheckAttemptController()
|
||||
diskController := disk.GetDiskController()
|
||||
backupConfigController := backups_config.GetBackupConfigController()
|
||||
|
||||
downdetectContoller.RegisterRoutes(v1)
|
||||
// Public routes (only user auth routes and healthcheck should be public)
|
||||
userController := users_controllers.GetUserController()
|
||||
userController.RegisterRoutes(v1)
|
||||
notifierController.RegisterRoutes(v1)
|
||||
storageController.RegisterRoutes(v1)
|
||||
databaseController.RegisterRoutes(v1)
|
||||
backupController.RegisterRoutes(v1)
|
||||
restoreController.RegisterRoutes(v1)
|
||||
healthcheckController.RegisterRoutes(v1)
|
||||
diskController.RegisterRoutes(v1)
|
||||
healthcheckConfigController.RegisterRoutes(v1)
|
||||
healthcheckAttemptController.RegisterRoutes(v1)
|
||||
backupConfigController.RegisterRoutes(v1)
|
||||
system_healthcheck.GetHealthcheckController().RegisterRoutes(v1)
|
||||
|
||||
// Setup auth middleware
|
||||
userService := users_services.GetUserService()
|
||||
authMiddleware := users_middleware.AuthMiddleware(userService)
|
||||
|
||||
// Protected routes
|
||||
protected := v1.Group("")
|
||||
protected.Use(authMiddleware)
|
||||
|
||||
userController.RegisterProtectedRoutes(protected)
|
||||
workspaces_controllers.GetWorkspaceController().RegisterRoutes(protected)
|
||||
workspaces_controllers.GetMembershipController().RegisterRoutes(protected)
|
||||
disk.GetDiskController().RegisterRoutes(protected)
|
||||
notifiers.GetNotifierController().RegisterRoutes(protected)
|
||||
storages.GetStorageController().RegisterRoutes(protected)
|
||||
databases.GetDatabaseController().RegisterRoutes(protected)
|
||||
backups.GetBackupController().RegisterRoutes(protected)
|
||||
restores.GetRestoreController().RegisterRoutes(protected)
|
||||
healthcheck_config.GetHealthcheckConfigController().RegisterRoutes(protected)
|
||||
healthcheck_attempt.GetHealthcheckAttemptController().RegisterRoutes(protected)
|
||||
backups_config.GetBackupConfigController().RegisterRoutes(protected)
|
||||
audit_logs.GetAuditLogController().RegisterRoutes(protected)
|
||||
users_controllers.GetManagementController().RegisterRoutes(protected)
|
||||
users_controllers.GetSettingsController().RegisterRoutes(protected)
|
||||
}
|
||||
|
||||
func setUpDependencies() {
|
||||
backups.SetupDependencies()
|
||||
databases.SetupDependencies()
|
||||
backups.SetupDependencies()
|
||||
restores.SetupDependencies()
|
||||
healthcheck_config.SetupDependencies()
|
||||
audit_logs.SetupDependencies()
|
||||
notifiers.SetupDependencies()
|
||||
storages.SetupDependencies()
|
||||
}
|
||||
|
||||
func runBackgroundTasks(log *slog.Logger) {
|
||||
@@ -197,7 +236,7 @@ func runBackgroundTasks(log *slog.Logger) {
|
||||
})
|
||||
|
||||
go runWithPanicLogging(log, "healthcheck attempt background service", func() {
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().RunBackgroundTasks()
|
||||
healthcheck_attempt.GetHealthcheckAttemptBackgroundService().Run()
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -31,7 +31,26 @@ services:
|
||||
container_name: test-minio
|
||||
command: server /data --console-address ":9001"
|
||||
|
||||
# Test Azurite container
|
||||
test-azurite:
|
||||
image: mcr.microsoft.com/azure-storage/azurite
|
||||
ports:
|
||||
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
|
||||
container_name: test-azurite
|
||||
command: azurite-blob --blobHost 0.0.0.0
|
||||
|
||||
# Test PostgreSQL containers
|
||||
test-postgres-12:
|
||||
image: postgres:12
|
||||
ports:
|
||||
- "${TEST_POSTGRES_12_PORT}:5432"
|
||||
environment:
|
||||
- POSTGRES_DB=testdb
|
||||
- POSTGRES_USER=testuser
|
||||
- POSTGRES_PASSWORD=testpassword
|
||||
container_name: test-postgres-12
|
||||
shm_size: 1gb
|
||||
|
||||
test-postgres-13:
|
||||
image: postgres:13
|
||||
ports:
|
||||
@@ -87,6 +106,17 @@ services:
|
||||
container_name: test-postgres-17
|
||||
shm_size: 1gb
|
||||
|
||||
test-postgres-18:
|
||||
image: postgres:18
|
||||
ports:
|
||||
- "${TEST_POSTGRES_18_PORT}:5432"
|
||||
environment:
|
||||
- POSTGRES_DB=testdb
|
||||
- POSTGRES_USER=testuser
|
||||
- POSTGRES_PASSWORD=testpassword
|
||||
container_name: test-postgres-18
|
||||
shm_size: 1gb
|
||||
|
||||
# Test NAS server (Samba)
|
||||
test-nas:
|
||||
image: dperson/samba:latest
|
||||
|
||||
@@ -3,6 +3,8 @@ module postgresus-backend
|
||||
go 1.23.3
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3
|
||||
github.com/gin-contrib/cors v1.7.5
|
||||
github.com/gin-contrib/gzip v1.2.3
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
@@ -15,16 +17,18 @@ require (
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/minio/minio-go/v7 v7.0.92
|
||||
github.com/shirou/gopsutil/v4 v4.25.5
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
golang.org/x/crypto v0.39.0
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/time v0.12.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
|
||||
require github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.16.2 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
@@ -99,12 +103,12 @@ require (
|
||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.41.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.15.0 // indirect
|
||||
golang.org/x/sys v0.33.0 // indirect
|
||||
golang.org/x/text v0.26.0 // indirect
|
||||
golang.org/x/tools v0.33.0 // indirect
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
google.golang.org/api v0.239.0
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
|
||||
@@ -6,6 +6,18 @@ cloud.google.com/go/compute/metadata v0.7.0 h1:PBWF+iiAerVNe8UCHxdOt6eHLVc3ydFeO
|
||||
cloud.google.com/go/compute/metadata v0.7.0/go.mod h1:j5MvL9PprKL39t166CoB1uVHfQMs4tFQZZcKwksXUjo=
|
||||
filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA=
|
||||
filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 h1:KpMC6LFL7mqpExyMC9jVOYRiVhLmamjeZfRsUpB7l4s=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1 h1:/Zt+cDPnpC3OVDm/JKLOs7M2DKmLRIIp3XIx9pHHiig=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.8.1/go.mod h1:Ng3urmn6dYe8gnbCMoHHVl5APYz2txho3koEkV2o2HA=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3 h1:ZJJNFaQ86GVKQ9ehwqyAFE6pIfyicpuJ8IkVaPBc6/4=
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.6.3/go.mod h1:URuDvhmATVKqHBH9/0nOiNKk0+YcwfQ3WkK5PqHKxc8=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI=
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
|
||||
github.com/BurntSushi/toml v1.2.1/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
|
||||
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||
@@ -80,6 +92,8 @@ github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
@@ -131,6 +145,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
|
||||
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
@@ -159,6 +175,8 @@ github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c h1:dAMKvw0MlJT1GshSTtih8C2gDs04w8dReiOGXrGLNoY=
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c/go.mod h1:RqIHx9QI14HlwKwm98g9Re5prTQ6LdeRQn+gXJFxsJM=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
|
||||
@@ -180,8 +198,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
|
||||
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE=
|
||||
github.com/swaggo/files v1.0.1/go.mod h1:0qXmMNH6sXNf+73t65aKeB+ApmgxdnkQzVTAj2uaMUg=
|
||||
github.com/swaggo/gin-swagger v1.6.0 h1:y8sxvQ3E20/RCyrXeFfg60r6H0Z+SwpTjMYsMm+zy8M=
|
||||
@@ -216,25 +234,25 @@ golang.org/x/arch v0.17.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20200728195943-123391ffb6de/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM=
|
||||
golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U=
|
||||
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.41.0 h1:vBTly1HeNPEn3wtREYfy4GZ/NECgw2Cnl+nK6Nz3uvw=
|
||||
golang.org/x/net v0.41.0/go.mod h1:B/K4NNqkfmg07DQYrbwvSluqCJOOXwUjeb/5lOisjbA=
|
||||
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
|
||||
golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.15.0 h1:KWH3jNZsfyT6xfAfKiz6MRNmd46ByHDYaZ7KSkCtdW8=
|
||||
golang.org/x/sync v0.15.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -247,8 +265,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
|
||||
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
@@ -257,15 +275,15 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
|
||||
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
|
||||
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
|
||||
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
|
||||
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/api v0.239.0 h1:2hZKUnFZEy81eugPs4e2XzIJ5SOwQg0G82bpXD65Puo=
|
||||
google.golang.org/api v0.239.0/go.mod h1:cOVEm2TpdAGHL2z+UwyS+kmlGr3bVWQQ6sYEqkKje50=
|
||||
|
||||
@@ -26,23 +26,38 @@ type EnvVariables struct {
|
||||
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
|
||||
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
SecretKeyPath string
|
||||
|
||||
TestGoogleDriveClientID string `env:"TEST_GOOGLE_DRIVE_CLIENT_ID"`
|
||||
TestGoogleDriveClientSecret string `env:"TEST_GOOGLE_DRIVE_CLIENT_SECRET"`
|
||||
TestGoogleDriveTokenJSON string `env:"TEST_GOOGLE_DRIVE_TOKEN_JSON"`
|
||||
|
||||
TestPostgres12Port string `env:"TEST_POSTGRES_12_PORT"`
|
||||
TestPostgres13Port string `env:"TEST_POSTGRES_13_PORT"`
|
||||
TestPostgres14Port string `env:"TEST_POSTGRES_14_PORT"`
|
||||
TestPostgres15Port string `env:"TEST_POSTGRES_15_PORT"`
|
||||
TestPostgres16Port string `env:"TEST_POSTGRES_16_PORT"`
|
||||
TestPostgres17Port string `env:"TEST_POSTGRES_17_PORT"`
|
||||
TestPostgres18Port string `env:"TEST_POSTGRES_18_PORT"`
|
||||
|
||||
TestMinioPort string `env:"TEST_MINIO_PORT"`
|
||||
TestMinioConsolePort string `env:"TEST_MINIO_CONSOLE_PORT"`
|
||||
|
||||
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
|
||||
|
||||
TestNASPort string `env:"TEST_NAS_PORT"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET"`
|
||||
GoogleClientID string `env:"GOOGLE_CLIENT_ID"`
|
||||
GoogleClientSecret string `env:"GOOGLE_CLIENT_SECRET"`
|
||||
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -132,8 +147,13 @@ func loadEnvVariables() {
|
||||
// (projectRoot/postgresus-data -> /postgresus-data)
|
||||
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "backups")
|
||||
env.TempFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "temp")
|
||||
env.SecretKeyPath = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "secret.key")
|
||||
|
||||
if env.IsTesting {
|
||||
if env.TestPostgres12Port == "" {
|
||||
log.Error("TEST_POSTGRES_12_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.TestPostgres13Port == "" {
|
||||
log.Error("TEST_POSTGRES_13_PORT is empty")
|
||||
os.Exit(1)
|
||||
@@ -154,6 +174,10 @@ func loadEnvVariables() {
|
||||
log.Error("TEST_POSTGRES_17_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
if env.TestPostgres18Port == "" {
|
||||
log.Error("TEST_POSTGRES_18_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestMinioPort == "" {
|
||||
log.Error("TEST_MINIO_PORT is empty")
|
||||
@@ -164,10 +188,25 @@ func loadEnvVariables() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestAzuriteBlobPort == "" {
|
||||
log.Error("TEST_AZURITE_BLOB_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestNASPort == "" {
|
||||
log.Error("TEST_NAS_PORT is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramBotToken == "" {
|
||||
log.Error("TEST_TELEGRAM_BOT_TOKEN is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if env.TestTelegramChatID == "" {
|
||||
log.Error("TEST_TELEGRAM_CHAT_ID is empty")
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
log.Info("Environment variables loaded successfully!")
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type DowndetectController struct {
|
||||
service *DowndetectService
|
||||
}
|
||||
|
||||
func (c *DowndetectController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/downdetect/is-available", c.IsAvailable)
|
||||
}
|
||||
|
||||
// @Summary Check API availability
|
||||
// @Description Checks if the API service is available
|
||||
// @Tags downdetect
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200
|
||||
// @Failure 500
|
||||
// @Router /downdetect/api [get]
|
||||
func (c *DowndetectController) IsAvailable(ctx *gin.Context) {
|
||||
err := c.service.IsDbAvailable()
|
||||
if err != nil {
|
||||
ctx.JSON(
|
||||
http.StatusInternalServerError,
|
||||
gin.H{"error": fmt.Sprintf("Database is not available: %v", err)},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"message": "API and DB are available"})
|
||||
}
|
||||
@@ -1,10 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
var downdetectService = &DowndetectService{}
|
||||
var downdetectController = &DowndetectController{
|
||||
downdetectService,
|
||||
}
|
||||
|
||||
func GetDowndetectController() *DowndetectController {
|
||||
return downdetectController
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
package downdetect
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/storage"
|
||||
)
|
||||
|
||||
type DowndetectService struct {
|
||||
}
|
||||
|
||||
func (s *DowndetectService) IsDbAvailable() error {
|
||||
err := storage.GetDb().Exec("SELECT 1").Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
111
backend/internal/features/audit_logs/controller.go
Normal file
111
backend/internal/features/audit_logs/controller.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
user_models "postgresus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogController struct {
|
||||
auditLogService *AuditLogService
|
||||
}
|
||||
|
||||
func (c *AuditLogController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// All audit log endpoints require authentication (handled in main.go)
|
||||
auditRoutes := router.Group("/audit-logs")
|
||||
|
||||
auditRoutes.GET("/global", c.GetGlobalAuditLogs)
|
||||
auditRoutes.GET("/users/:userId", c.GetUserAuditLogs)
|
||||
}
|
||||
|
||||
// GetGlobalAuditLogs
|
||||
// @Summary Get global audit logs (ADMIN only)
|
||||
// @Description Retrieve all audit logs across the system
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/global [get]
|
||||
func (c *AuditLogController) GetGlobalAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetGlobalAuditLogs(user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "only administrators can view global audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// GetUserAuditLogs
|
||||
// @Summary Get user audit logs
|
||||
// @Description Retrieve audit logs for a specific user
|
||||
// @Tags audit-logs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param userId path string true "User ID"
|
||||
// @Param limit query int false "Limit number of results" default(100)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Param beforeDate query string false "Filter logs created before this date (RFC3339 format)" format(date-time)
|
||||
// @Success 200 {object} GetAuditLogsResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /audit-logs/users/{userId} [get]
|
||||
func (c *AuditLogController) GetUserAuditLogs(ctx *gin.Context) {
|
||||
user, isOk := ctx.MustGet("user").(*user_models.User)
|
||||
if !isOk {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid user type in context"})
|
||||
return
|
||||
}
|
||||
|
||||
userIDStr := ctx.Param("userId")
|
||||
targetUserID, err := uuid.Parse(userIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
request := &GetAuditLogsRequest{}
|
||||
if err := ctx.ShouldBindQuery(request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "Invalid query parameters"})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := c.auditLogService.GetUserAuditLogs(targetUserID, user, request)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view user audit logs" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to retrieve audit logs"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
154
backend/internal/features/audit_logs/controller_test.go
Normal file
154
backend/internal/features/audit_logs/controller_test.go
Normal file
@@ -0,0 +1,154 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
memberUser := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
workspaceID := uuid.New()
|
||||
testID := uuid.New().String()
|
||||
|
||||
// Create test logs with unique identifiers
|
||||
userLogMessage := fmt.Sprintf("Test log with user %s", testID)
|
||||
workspaceLogMessage := fmt.Sprintf("Test log with workspace %s", testID)
|
||||
standaloneLogMessage := fmt.Sprintf("Test log standalone %s", testID)
|
||||
|
||||
createAuditLog(service, userLogMessage, &adminUser.UserID, nil)
|
||||
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
|
||||
createAuditLog(service, standaloneLogMessage, nil, nil)
|
||||
|
||||
// Test ADMIN can access global logs
|
||||
var response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
"/api/v1/audit-logs/global?limit=100", "Bearer "+adminUser.Token, http.StatusOK, &response)
|
||||
|
||||
// Verify our specific test logs are present
|
||||
messages := extractMessages(response.AuditLogs)
|
||||
assert.Contains(t, messages, userLogMessage)
|
||||
assert.Contains(t, messages, workspaceLogMessage)
|
||||
assert.Contains(t, messages, standaloneLogMessage)
|
||||
|
||||
// Test MEMBER cannot access global logs
|
||||
resp := test_utils.MakeGetRequest(t, router, "/api/v1/audit-logs/global",
|
||||
"Bearer "+memberUser.Token, http.StatusForbidden)
|
||||
assert.Contains(t, string(resp.Body), "only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
func Test_GetUserAuditLogs_WithDifferentUserRoles_EnforcesPermissionsCorrectly(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
router := createRouter()
|
||||
service := GetAuditLogService()
|
||||
workspaceID := uuid.New()
|
||||
testID := uuid.New().String()
|
||||
|
||||
// Create test logs for different users with unique identifiers
|
||||
user1FirstMessage := fmt.Sprintf("Test log user1 first %s", testID)
|
||||
user1SecondMessage := fmt.Sprintf("Test log user1 second %s", testID)
|
||||
user2FirstMessage := fmt.Sprintf("Test log user2 first %s", testID)
|
||||
user2SecondMessage := fmt.Sprintf("Test log user2 second %s", testID)
|
||||
workspaceLogMessage := fmt.Sprintf("Test workspace log %s", testID)
|
||||
|
||||
createAuditLog(service, user1FirstMessage, &user1.UserID, nil)
|
||||
createAuditLog(service, user1SecondMessage, &user1.UserID, &workspaceID)
|
||||
createAuditLog(service, user2FirstMessage, &user2.UserID, nil)
|
||||
createAuditLog(service, user2SecondMessage, &user2.UserID, &workspaceID)
|
||||
createAuditLog(service, workspaceLogMessage, nil, &workspaceID)
|
||||
|
||||
// Test ADMIN can view any user's logs
|
||||
var user1Response GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user1.UserID.String()),
|
||||
"Bearer "+adminUser.Token, http.StatusOK, &user1Response)
|
||||
|
||||
// Verify user1's specific logs are present
|
||||
messages := extractMessages(user1Response.AuditLogs)
|
||||
assert.Contains(t, messages, user1FirstMessage)
|
||||
assert.Contains(t, messages, user1SecondMessage)
|
||||
|
||||
// Count only our test logs for user1
|
||||
testLogsCount := 0
|
||||
for _, message := range messages {
|
||||
if message == user1FirstMessage || message == user1SecondMessage {
|
||||
testLogsCount++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 2, testLogsCount)
|
||||
|
||||
// Test user can view own logs
|
||||
var ownLogsResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s?limit=100", user2.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusOK, &ownLogsResponse)
|
||||
|
||||
// Verify user2's specific logs are present
|
||||
ownMessages := extractMessages(ownLogsResponse.AuditLogs)
|
||||
assert.Contains(t, ownMessages, user2FirstMessage)
|
||||
assert.Contains(t, ownMessages, user2SecondMessage)
|
||||
|
||||
// Test user cannot view other user's logs
|
||||
resp := test_utils.MakeGetRequest(t, router,
|
||||
fmt.Sprintf("/api/v1/audit-logs/users/%s", user1.UserID.String()),
|
||||
"Bearer "+user2.Token, http.StatusForbidden)
|
||||
|
||||
assert.Contains(t, string(resp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetGlobalAuditLogs_WithBeforeDateFilter_ReturnsFilteredLogs(t *testing.T) {
|
||||
adminUser := users_testing.CreateTestUser(user_enums.UserRoleAdmin)
|
||||
router := createRouter()
|
||||
baseTime := time.Now().UTC()
|
||||
|
||||
// Set filter time to 30 minutes ago
|
||||
beforeTime := baseTime.Add(-30 * time.Minute)
|
||||
|
||||
var filteredResponse GetAuditLogsResponse
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/audit-logs/global?beforeDate=%s&limit=1000",
|
||||
beforeTime.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+adminUser.Token,
|
||||
http.StatusOK,
|
||||
&filteredResponse,
|
||||
)
|
||||
|
||||
// Verify ALL returned logs are older than the filter time
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime),
|
||||
fmt.Sprintf("Log created at %s should be before filter time %s",
|
||||
log.CreatedAt.Format(time.RFC3339), beforeTime.Format(time.RFC3339)))
|
||||
}
|
||||
}
|
||||
|
||||
func createRouter() *gin.Engine {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
SetupDependencies()
|
||||
|
||||
v1 := router.Group("/api/v1")
|
||||
protected := v1.Group("").Use(users_middleware.AuthMiddleware(users_services.GetUserService()))
|
||||
GetAuditLogController().RegisterRoutes(protected.(*gin.RouterGroup))
|
||||
|
||||
return router
|
||||
}
|
||||
29
backend/internal/features/audit_logs/di.go
Normal file
29
backend/internal/features/audit_logs/di.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var auditLogRepository = &AuditLogRepository{}
|
||||
var auditLogService = &AuditLogService{
|
||||
auditLogRepository: auditLogRepository,
|
||||
logger: logger.GetLogger(),
|
||||
}
|
||||
var auditLogController = &AuditLogController{
|
||||
auditLogService: auditLogService,
|
||||
}
|
||||
|
||||
func GetAuditLogService() *AuditLogService {
|
||||
return auditLogService
|
||||
}
|
||||
|
||||
func GetAuditLogController() *AuditLogController {
|
||||
return auditLogController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
users_services.GetUserService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetSettingsService().SetAuditLogWriter(auditLogService)
|
||||
users_services.GetManagementService().SetAuditLogWriter(auditLogService)
|
||||
}
|
||||
31
backend/internal/features/audit_logs/dto.go
Normal file
31
backend/internal/features/audit_logs/dto.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type GetAuditLogsRequest struct {
|
||||
Limit int `form:"limit" json:"limit"`
|
||||
Offset int `form:"offset" json:"offset"`
|
||||
BeforeDate *time.Time `form:"beforeDate" json:"beforeDate"`
|
||||
}
|
||||
|
||||
type GetAuditLogsResponse struct {
|
||||
AuditLogs []*AuditLogDTO `json:"auditLogs"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type AuditLogDTO struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
UserEmail *string `json:"userEmail" gorm:"column:user_email"`
|
||||
UserName *string `json:"userName" gorm:"column:user_name"`
|
||||
WorkspaceName *string `json:"workspaceName" gorm:"column:workspace_name"`
|
||||
}
|
||||
19
backend/internal/features/audit_logs/models.go
Normal file
19
backend/internal/features/audit_logs/models.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLog struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id"`
|
||||
UserID *uuid.UUID `json:"userId" gorm:"column:user_id"`
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id"`
|
||||
Message string `json:"message" gorm:"column:message"`
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
func (AuditLog) TableName() string {
|
||||
return "audit_logs"
|
||||
}
|
||||
139
backend/internal/features/audit_logs/repository.go
Normal file
139
backend/internal/features/audit_logs/repository.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/storage"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogRepository struct{}
|
||||
|
||||
func (r *AuditLogRepository) Create(auditLog *AuditLog) error {
|
||||
if auditLog.ID == uuid.Nil {
|
||||
auditLog.ID = uuid.New()
|
||||
}
|
||||
|
||||
return storage.GetDb().Create(auditLog).Error
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetGlobal(
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id`
|
||||
|
||||
args := []interface{}{}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " WHERE al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByUser(
|
||||
userID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.user_id = ?`
|
||||
|
||||
args := []interface{}{userID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) GetByWorkspace(
|
||||
workspaceID uuid.UUID,
|
||||
limit, offset int,
|
||||
beforeDate *time.Time,
|
||||
) ([]*AuditLogDTO, error) {
|
||||
var auditLogs = make([]*AuditLogDTO, 0)
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
al.id,
|
||||
al.user_id,
|
||||
al.workspace_id,
|
||||
al.message,
|
||||
al.created_at,
|
||||
u.email as user_email,
|
||||
u.name as user_name,
|
||||
w.name as workspace_name
|
||||
FROM audit_logs al
|
||||
LEFT JOIN users u ON al.user_id = u.id
|
||||
LEFT JOIN workspaces w ON al.workspace_id = w.id
|
||||
WHERE al.workspace_id = ?`
|
||||
|
||||
args := []interface{}{workspaceID}
|
||||
|
||||
if beforeDate != nil {
|
||||
sql += " AND al.created_at < ?"
|
||||
args = append(args, *beforeDate)
|
||||
}
|
||||
|
||||
sql += " ORDER BY al.created_at DESC LIMIT ? OFFSET ?"
|
||||
args = append(args, limit, offset)
|
||||
|
||||
err := storage.GetDb().Raw(sql, args...).Scan(&auditLogs).Error
|
||||
|
||||
return auditLogs, err
|
||||
}
|
||||
|
||||
func (r *AuditLogRepository) CountGlobal(beforeDate *time.Time) (int64, error) {
|
||||
var count int64
|
||||
query := storage.GetDb().Model(&AuditLog{})
|
||||
|
||||
if beforeDate != nil {
|
||||
query = query.Where("created_at < ?", *beforeDate)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
137
backend/internal/features/audit_logs/service.go
Normal file
137
backend/internal/features/audit_logs/service.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
user_models "postgresus-backend/internal/features/users/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AuditLogService struct {
|
||||
auditLogRepository *AuditLogRepository
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *AuditLogService) WriteAuditLog(
|
||||
message string,
|
||||
userID *uuid.UUID,
|
||||
workspaceID *uuid.UUID,
|
||||
) {
|
||||
auditLog := &AuditLog{
|
||||
UserID: userID,
|
||||
WorkspaceID: workspaceID,
|
||||
Message: message,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err := s.auditLogRepository.Create(auditLog)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to create audit log", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuditLogService) CreateAuditLog(auditLog *AuditLog) error {
|
||||
return s.auditLogRepository.Create(auditLog)
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetGlobalAuditLogs(
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
if user.Role != user_enums.UserRoleAdmin {
|
||||
return nil, errors.New("only administrators can view global audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetGlobal(limit, offset, request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
total, err := s.auditLogRepository.CountGlobal(request.BeforeDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetUserAuditLogs(
|
||||
targetUserID uuid.UUID,
|
||||
user *user_models.User,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
// Users can view their own logs, ADMIN can view any user's logs
|
||||
if user.Role != user_enums.UserRoleAdmin && user.ID != targetUserID {
|
||||
return nil, errors.New("insufficient permissions to view user audit logs")
|
||||
}
|
||||
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByUser(
|
||||
targetUserID,
|
||||
limit,
|
||||
offset,
|
||||
request.BeforeDate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *AuditLogService) GetWorkspaceAuditLogs(
|
||||
workspaceID uuid.UUID,
|
||||
request *GetAuditLogsRequest,
|
||||
) (*GetAuditLogsResponse, error) {
|
||||
limit := request.Limit
|
||||
if limit <= 0 || limit > 1000 {
|
||||
limit = 100
|
||||
}
|
||||
|
||||
offset := max(request.Offset, 0)
|
||||
|
||||
auditLogs, err := s.auditLogRepository.GetByWorkspace(
|
||||
workspaceID,
|
||||
limit,
|
||||
offset,
|
||||
request.BeforeDate,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetAuditLogsResponse{
|
||||
AuditLogs: auditLogs,
|
||||
Total: int64(len(auditLogs)),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
83
backend/internal/features/audit_logs/service_test.go
Normal file
83
backend/internal/features/audit_logs/service_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package audit_logs
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
user_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_AuditLogs_WorkspaceSpecificLogs(t *testing.T) {
|
||||
service := GetAuditLogService()
|
||||
user1 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
user2 := users_testing.CreateTestUser(user_enums.UserRoleMember)
|
||||
workspace1ID, workspace2ID := uuid.New(), uuid.New()
|
||||
|
||||
// Create test logs for workspaces
|
||||
createAuditLog(service, "Test workspace1 log first", &user1.UserID, &workspace1ID)
|
||||
createAuditLog(service, "Test workspace1 log second", &user2.UserID, &workspace1ID)
|
||||
createAuditLog(service, "Test workspace2 log first", &user1.UserID, &workspace2ID)
|
||||
createAuditLog(service, "Test workspace2 log second", &user2.UserID, &workspace2ID)
|
||||
createAuditLog(service, "Test no workspace log", &user1.UserID, nil)
|
||||
|
||||
request := &GetAuditLogsRequest{Limit: 10, Offset: 0}
|
||||
|
||||
// Test workspace 1 logs
|
||||
workspace1Response, err := service.GetWorkspaceAuditLogs(workspace1ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(workspace1Response.AuditLogs))
|
||||
|
||||
messages := extractMessages(workspace1Response.AuditLogs)
|
||||
assert.Contains(t, messages, "Test workspace1 log first")
|
||||
assert.Contains(t, messages, "Test workspace1 log second")
|
||||
for _, log := range workspace1Response.AuditLogs {
|
||||
assert.Equal(t, &workspace1ID, log.WorkspaceID)
|
||||
}
|
||||
|
||||
// Test workspace 2 logs
|
||||
workspace2Response, err := service.GetWorkspaceAuditLogs(workspace2ID, request)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(workspace2Response.AuditLogs))
|
||||
|
||||
messages2 := extractMessages(workspace2Response.AuditLogs)
|
||||
assert.Contains(t, messages2, "Test workspace2 log first")
|
||||
assert.Contains(t, messages2, "Test workspace2 log second")
|
||||
|
||||
// Test pagination
|
||||
limitedResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
|
||||
&GetAuditLogsRequest{Limit: 1, Offset: 0})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(limitedResponse.AuditLogs))
|
||||
assert.Equal(t, 1, limitedResponse.Limit)
|
||||
|
||||
// Test beforeDate filter
|
||||
beforeTime := time.Now().UTC().Add(-1 * time.Minute)
|
||||
filteredResponse, err := service.GetWorkspaceAuditLogs(workspace1ID,
|
||||
&GetAuditLogsRequest{Limit: 10, BeforeDate: &beforeTime})
|
||||
assert.NoError(t, err)
|
||||
for _, log := range filteredResponse.AuditLogs {
|
||||
assert.True(t, log.CreatedAt.Before(beforeTime))
|
||||
assert.NotNil(t, log.UserEmail, "User email should be present for logs with user_id")
|
||||
assert.NotNil(
|
||||
t,
|
||||
log.WorkspaceName,
|
||||
"Workspace name should be present for logs with workspace_id",
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func createAuditLog(service *AuditLogService, message string, userID, workspaceID *uuid.UUID) {
|
||||
service.WriteAuditLog(message, userID, workspaceID)
|
||||
}
|
||||
|
||||
func extractMessages(logs []*AuditLogDTO) []string {
|
||||
messages := make([]string, len(logs))
|
||||
for i, log := range logs {
|
||||
messages[i] = log.Message
|
||||
}
|
||||
return messages
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"postgresus-backend/internal/config"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/period"
|
||||
"time"
|
||||
)
|
||||
@@ -131,7 +132,8 @@ func (s *BackupBackgroundService) cleanOldBackups() error {
|
||||
continue
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
err = storage.DeleteFile(encryptor, backup.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to delete backup file", "backupId", backup.ID, "error", err)
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/period"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,10 +18,12 @@ import (
|
||||
|
||||
func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -40,11 +44,8 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
|
||||
// add old backup
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
@@ -67,16 +68,20 @@ func Test_MakeBackupForDbHavingBackupDayAgo_BackupCreated(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -97,11 +102,8 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
|
||||
// add recent backup (1 hour ago)
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusCompleted,
|
||||
|
||||
@@ -124,16 +126,20 @@ func Test_MakeBackupForDbHavingHourAgoBackup_BackupSkipped(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries disabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -157,11 +163,8 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -185,16 +188,20 @@ func Test_MakeBackupHavingFailedBackupWithoutRetries_BackupSkipped(t *testing.T)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -218,11 +225,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
// add failed backup
|
||||
failMessage := "backup failed"
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -246,16 +250,20 @@ func Test_MakeBackupHavingFailedBackupWithRetries_BackupCreated(t *testing.T) {
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(100 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *testing.T) {
|
||||
// setup data
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
|
||||
// Enable backups for the database with retries enabled
|
||||
backupConfig, err := backups_config.GetBackupConfigService().GetBackupConfigByDbId(database.ID)
|
||||
@@ -280,11 +288,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
backupRepository.Save(&Backup{
|
||||
Database: database,
|
||||
DatabaseID: database.ID,
|
||||
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusFailed,
|
||||
FailMessage: &failMessage,
|
||||
@@ -309,6 +314,8 @@ func Test_MakeBackupHavingFailedBackupWithRetries_RetriesCountNotExceeded(t *tes
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc context.CancelFunc) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if !exists {
|
||||
return errors.New("backup is not in progress or already completed")
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
|
||||
type BackupController struct {
|
||||
backupService *BackupService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -20,51 +19,48 @@ func (c *BackupController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.POST("/backups", c.MakeBackup)
|
||||
router.GET("/backups/:id/file", c.GetFile)
|
||||
router.DELETE("/backups/:id", c.DeleteBackup)
|
||||
router.POST("/backups/:id/cancel", c.CancelBackup)
|
||||
}
|
||||
|
||||
// GetBackups
|
||||
// @Summary Get backups for a database
|
||||
// @Description Get all backups for the specified database
|
||||
// @Description Get paginated backups for the specified database
|
||||
// @Tags backups
|
||||
// @Produce json
|
||||
// @Param database_id query string true "Database ID"
|
||||
// @Success 200 {array} Backup
|
||||
// @Param limit query int false "Number of items per page" default(10)
|
||||
// @Param offset query int false "Offset for pagination" default(0)
|
||||
// @Success 200 {object} GetBackupsResponse
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups [get]
|
||||
func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
databaseIDStr := ctx.Query("database_id")
|
||||
if databaseIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "database_id query parameter is required"})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(databaseIDStr)
|
||||
var request GetBackupsRequest
|
||||
if err := ctx.ShouldBindQuery(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
databaseID, err := uuid.Parse(request.DatabaseID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database_id"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
backups, err := c.backupService.GetBackups(user, databaseID)
|
||||
response, err := c.backupService.GetBackups(user, databaseID, request.Limit, request.Offset)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, backups)
|
||||
ctx.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// MakeBackup
|
||||
@@ -80,24 +76,18 @@ func (c *BackupController) GetBackups(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups [post]
|
||||
func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request MakeBackupRequest
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.MakeBackupWithAuth(user, request.DatabaseID); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -117,24 +107,18 @@ func (c *BackupController) MakeBackup(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups/{id} [delete]
|
||||
func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.DeleteBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -143,6 +127,37 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// CancelBackup
|
||||
// @Summary Cancel an in-progress backup
|
||||
// @Description Cancel a backup that is currently in progress
|
||||
// @Tags backups
|
||||
// @Param id path string true "Backup ID"
|
||||
// @Success 204
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/cancel [post]
|
||||
func (c *BackupController) CancelBackup(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.backupService.CancelBackup(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Status(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// GetFile
|
||||
// @Summary Download a backup file
|
||||
// @Description Download the backup file for the specified backup
|
||||
@@ -154,24 +169,18 @@ func (c *BackupController) DeleteBackup(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backups/{id}/file [get]
|
||||
func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid backup ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, err := c.backupService.GetBackupFile(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -179,19 +188,16 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
}
|
||||
defer func() {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
// Log the error but don't interrupt the response
|
||||
fmt.Printf("Error closing file reader: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set headers for file download
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"backup_%s.dump\"", id.String()),
|
||||
)
|
||||
|
||||
// Stream the file content
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": "failed to stream file"})
|
||||
|
||||
709
backend/internal/features/backups/backups/controller_test.go
Normal file
709
backend/internal/features/backups/backups/controller_test.go
Normal file
@@ -0,0 +1,709 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
users_dto "postgresus-backend/internal/features/users/dto"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_models "postgresus-backend/internal/features/workspaces/models"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_GetBackups_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get backups",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can get backups",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, _ := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups?database_id=%s", database.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response GetBackupsResponse
|
||||
err := json.Unmarshal(testResp.Body, &response)
|
||||
assert.NoError(t, err)
|
||||
assert.GreaterOrEqual(t, len(response.Backups), 1)
|
||||
assert.GreaterOrEqual(t, response.Total, int64(1))
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can create backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can create backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "backup started successfully")
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
enableBackupForDatabase(database.ID)
|
||||
|
||||
request := MakeBackupRequest{DatabaseID: database.ID}
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backups",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup manually initiated") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup creation not found")
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace member can delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot delete backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can delete backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusNoContent,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
} else {
|
||||
userService := users_services.GetUserService()
|
||||
ownerUser, err := userService.GetUserFromToken(owner.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
response, err := GetBackupService().GetBackups(ownerUser, database.ID, 10, 0)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 0, len(response.Backups))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DeleteBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup deleted") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup deletion not found")
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace viewer can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can download backup",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can download backup",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
if *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
|
||||
if !tt.expectSuccess {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_DownloadBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database, backup := createTestDatabaseWithBackups(workspace, owner, router)
|
||||
|
||||
test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/file", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetWorkspaceAuditLogs(
|
||||
workspace.ID,
|
||||
&audit_logs.GetAuditLogsRequest{
|
||||
Limit: 100,
|
||||
Offset: 0,
|
||||
},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
found := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup file downloaded") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "Audit log for backup download not found")
|
||||
}
|
||||
|
||||
func Test_CancelBackup_InProgressBackup_SuccessfullyCancelled(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
assert.NoError(t, err)
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
assert.NoError(t, err)
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storage.ID,
|
||||
Status: BackupStatusInProgress,
|
||||
BackupSizeMb: 0,
|
||||
BackupDurationMs: 0,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
err = repo.Save(backup)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Register a cancellable context for the backup
|
||||
GetBackupService().backupContextManager.RegisterBackup(backup.ID, func() {})
|
||||
|
||||
resp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/backups/%s/cancel", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
nil,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, resp.StatusCode)
|
||||
|
||||
// Verify audit log was created
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
userService := users_services.GetUserService()
|
||||
adminUser, err := userService.GetUserFromToken(admin.Token)
|
||||
assert.NoError(t, err)
|
||||
|
||||
auditLogService := audit_logs.GetAuditLogService()
|
||||
auditLogs, err := auditLogService.GetGlobalAuditLogs(
|
||||
adminUser,
|
||||
&audit_logs.GetAuditLogsRequest{Limit: 100, Offset: 0},
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
foundCancelLog := false
|
||||
for _, log := range auditLogs.AuditLogs {
|
||||
if strings.Contains(log.Message, "Backup cancelled") &&
|
||||
strings.Contains(log.Message, database.Name) {
|
||||
foundCancelLog = true
|
||||
break
|
||||
}
|
||||
}
|
||||
assert.True(t, foundCancelLog, "Cancel audit log should be created")
|
||||
}
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
return CreateTestRouter()
|
||||
}
|
||||
|
||||
func createTestDatabase(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic(
|
||||
fmt.Sprintf("Failed to create database. Status: %d, Body: %s", w.Code, w.Body.String()),
|
||||
)
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
storage := &storages.Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: storages.StorageTypeLocal,
|
||||
Name: "Test Storage " + uuid.New().String(),
|
||||
LocalStorage: &local_storage.LocalStorage{},
|
||||
}
|
||||
|
||||
repo := &storages.StorageRepository{}
|
||||
storage, err := repo.Save(storage)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
func enableBackupForDatabase(databaseID uuid.UUID) {
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(databaseID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func createTestDatabaseWithBackups(
|
||||
workspace *workspaces_models.Workspace,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
router *gin.Engine,
|
||||
) (*databases.Database, *Backup) {
|
||||
database := createTestDatabase("Test Database", workspace.ID, owner.Token, router)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
configService := backups_config.GetBackupConfigService()
|
||||
config, err := configService.GetBackupConfigByDbId(database.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
config.IsBackupsEnabled = true
|
||||
config.StorageID = &storage.ID
|
||||
config.Storage = storage
|
||||
_, err = configService.SaveBackupConfig(config)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
backup := createTestBackup(database, owner)
|
||||
|
||||
return database, backup
|
||||
}
|
||||
|
||||
func createTestBackup(
|
||||
database *databases.Database,
|
||||
owner *users_dto.SignInResponseDTO,
|
||||
) *Backup {
|
||||
userService := users_services.GetUserService()
|
||||
user, err := userService.GetUserFromToken(owner.Token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
storages, err := storages.GetStorageService().GetStorages(user, *database.WorkspaceID)
|
||||
if err != nil || len(storages) == 0 {
|
||||
panic("No storage found for workspace")
|
||||
}
|
||||
|
||||
backup := &Backup{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: database.ID,
|
||||
StorageID: storages[0].ID,
|
||||
Status: BackupStatusCompleted,
|
||||
BackupSizeMb: 10.5,
|
||||
BackupDurationMs: 1000,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
repo := &BackupRepository{}
|
||||
if err := repo.Save(backup); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Create a dummy backup file for testing download functionality
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
return backup
|
||||
}
|
||||
@@ -1,17 +1,24 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups/usecases"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"time"
|
||||
)
|
||||
|
||||
var backupRepository = &BackupRepository{}
|
||||
|
||||
var backupContextManager = NewBackupContextManager()
|
||||
|
||||
var backupService = &BackupService{
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
@@ -19,9 +26,14 @@ var backupService = &BackupService{
|
||||
notifiers.GetNotifierService(),
|
||||
notifiers.GetNotifierService(),
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
usecases.GetCreateBackupUsecase(),
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
backupContextManager,
|
||||
}
|
||||
|
||||
var backupBackgroundService = &BackupBackgroundService{
|
||||
@@ -35,7 +47,6 @@ var backupBackgroundService = &BackupBackgroundService{
|
||||
|
||||
var backupController = &BackupController{
|
||||
backupService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
@@ -44,6 +55,7 @@ func SetupDependencies() {
|
||||
SetDatabaseStorageChangeListener(backupService)
|
||||
|
||||
databases.GetDatabaseService().AddDbRemoveListener(backupService)
|
||||
databases.GetDatabaseService().AddDbCopyListener(backups_config.GetBackupConfigService())
|
||||
}
|
||||
|
||||
func GetBackupService() *BackupService {
|
||||
|
||||
28
backend/internal/features/backups/backups/dto.go
Normal file
28
backend/internal/features/backups/backups/dto.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"io"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
)
|
||||
|
||||
type GetBackupsRequest struct {
|
||||
DatabaseID string `form:"database_id" binding:"required"`
|
||||
Limit int `form:"limit"`
|
||||
Offset int `form:"offset"`
|
||||
}
|
||||
|
||||
type GetBackupsResponse struct {
|
||||
Backups []*Backup `json:"backups"`
|
||||
Total int64 `json:"total"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
}
|
||||
|
||||
type decryptionReaderCloser struct {
|
||||
*encryption.DecryptionReader
|
||||
baseReader io.ReadCloser
|
||||
}
|
||||
|
||||
func (r *decryptionReaderCloser) Close() error {
|
||||
return r.baseReader.Close()
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DecryptionReader struct {
|
||||
baseReader io.Reader
|
||||
cipher cipher.AEAD
|
||||
buffer []byte
|
||||
nonce []byte
|
||||
chunkIndex uint64
|
||||
headerRead bool
|
||||
eof bool
|
||||
}
|
||||
|
||||
func NewDecryptionReader(
|
||||
baseReader io.Reader,
|
||||
masterKey string,
|
||||
backupID uuid.UUID,
|
||||
salt []byte,
|
||||
nonce []byte,
|
||||
) (*DecryptionReader, error) {
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
|
||||
}
|
||||
if len(nonce) != NonceLen {
|
||||
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive backup key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(derivedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
reader := &DecryptionReader{
|
||||
baseReader,
|
||||
aesgcm,
|
||||
make([]byte, 0),
|
||||
nonce,
|
||||
0,
|
||||
false,
|
||||
false,
|
||||
}
|
||||
|
||||
if err := reader.readAndValidateHeader(salt, nonce); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) Read(p []byte) (n int, err error) {
|
||||
for len(r.buffer) < len(p) && !r.eof {
|
||||
if err := r.readAndDecryptChunk(); err != nil {
|
||||
if err == io.EOF {
|
||||
r.eof = true
|
||||
break
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.buffer) == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
n = copy(p, r.buffer)
|
||||
r.buffer = r.buffer[n:]
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) readAndValidateHeader(expectedSalt, expectedNonce []byte) error {
|
||||
header := make([]byte, HeaderLen)
|
||||
|
||||
if _, err := io.ReadFull(r.baseReader, header); err != nil {
|
||||
return fmt.Errorf("failed to read header: %w", err)
|
||||
}
|
||||
|
||||
magic := string(header[0:MagicBytesLen])
|
||||
if magic != MagicBytes {
|
||||
return fmt.Errorf("invalid magic bytes: expected %s, got %s", MagicBytes, magic)
|
||||
}
|
||||
|
||||
salt := header[MagicBytesLen : MagicBytesLen+SaltLen]
|
||||
nonce := header[MagicBytesLen+SaltLen : MagicBytesLen+SaltLen+NonceLen]
|
||||
|
||||
if string(salt) != string(expectedSalt) {
|
||||
return fmt.Errorf("salt mismatch in file header")
|
||||
}
|
||||
|
||||
if string(nonce) != string(expectedNonce) {
|
||||
return fmt.Errorf("nonce mismatch in file header")
|
||||
}
|
||||
|
||||
r.headerRead = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) readAndDecryptChunk() error {
|
||||
lengthBuf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(r.baseReader, lengthBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chunkLen := binary.BigEndian.Uint32(lengthBuf)
|
||||
if chunkLen == 0 || chunkLen > ChunkSize+16 {
|
||||
return fmt.Errorf("invalid chunk length: %d", chunkLen)
|
||||
}
|
||||
|
||||
encrypted := make([]byte, chunkLen)
|
||||
if _, err := io.ReadFull(r.baseReader, encrypted); err != nil {
|
||||
return fmt.Errorf("failed to read encrypted chunk: %w", err)
|
||||
}
|
||||
|
||||
chunkNonce := r.generateChunkNonce()
|
||||
|
||||
decrypted, err := r.cipher.Open(nil, chunkNonce, encrypted, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf(
|
||||
"failed to decrypt chunk (authentication failed - file may be corrupted or tampered): %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
r.buffer = append(r.buffer, decrypted...)
|
||||
r.chunkIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DecryptionReader) generateChunkNonce() []byte {
|
||||
chunkNonce := make([]byte, NonceLen)
|
||||
copy(chunkNonce, r.nonce)
|
||||
|
||||
binary.BigEndian.PutUint64(chunkNonce[4:], r.chunkIndex)
|
||||
|
||||
return chunkNonce
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type EncryptionWriter struct {
|
||||
baseWriter io.Writer
|
||||
cipher cipher.AEAD
|
||||
buffer []byte
|
||||
nonce []byte
|
||||
salt []byte
|
||||
chunkIndex uint64
|
||||
headerWritten bool
|
||||
}
|
||||
|
||||
func NewEncryptionWriter(
|
||||
baseWriter io.Writer,
|
||||
masterKey string,
|
||||
backupID uuid.UUID,
|
||||
salt []byte,
|
||||
nonce []byte,
|
||||
) (*EncryptionWriter, error) {
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes, got %d", SaltLen, len(salt))
|
||||
}
|
||||
if len(nonce) != NonceLen {
|
||||
return nil, fmt.Errorf("nonce must be %d bytes, got %d", NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
derivedKey, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to derive backup key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(derivedKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
aesgcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
writer := &EncryptionWriter{
|
||||
baseWriter: baseWriter,
|
||||
cipher: aesgcm,
|
||||
buffer: make([]byte, 0, ChunkSize),
|
||||
nonce: nonce,
|
||||
chunkIndex: 0,
|
||||
headerWritten: false,
|
||||
salt: salt, // Store salt for lazy header writing
|
||||
}
|
||||
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) Write(p []byte) (n int, err error) {
|
||||
// Write header on first write (lazy initialization)
|
||||
if !w.headerWritten {
|
||||
if err := w.writeHeader(w.salt, w.nonce); err != nil {
|
||||
return 0, fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
n = len(p)
|
||||
w.buffer = append(w.buffer, p...)
|
||||
|
||||
for len(w.buffer) >= ChunkSize {
|
||||
chunk := w.buffer[:ChunkSize]
|
||||
if err := w.encryptAndWriteChunk(chunk); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
w.buffer = w.buffer[ChunkSize:]
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) Close() error {
|
||||
// Write header if it hasn't been written yet (in case Close is called without any writes)
|
||||
if !w.headerWritten {
|
||||
if err := w.writeHeader(w.salt, w.nonce); err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(w.buffer) > 0 {
|
||||
if err := w.encryptAndWriteChunk(w.buffer); err != nil {
|
||||
return err
|
||||
}
|
||||
w.buffer = nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) writeHeader(salt, nonce []byte) error {
|
||||
header := make([]byte, HeaderLen)
|
||||
|
||||
copy(header[0:MagicBytesLen], []byte(MagicBytes))
|
||||
copy(header[MagicBytesLen:MagicBytesLen+SaltLen], salt)
|
||||
copy(header[MagicBytesLen+SaltLen:MagicBytesLen+SaltLen+NonceLen], nonce)
|
||||
|
||||
_, err := w.baseWriter.Write(header)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to write header: %w", err)
|
||||
}
|
||||
|
||||
w.headerWritten = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) encryptAndWriteChunk(chunk []byte) error {
|
||||
chunkNonce := w.generateChunkNonce()
|
||||
|
||||
encrypted := w.cipher.Seal(nil, chunkNonce, chunk, nil)
|
||||
|
||||
lengthBuf := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(lengthBuf, uint32(len(encrypted)))
|
||||
|
||||
if _, err := w.baseWriter.Write(lengthBuf); err != nil {
|
||||
return fmt.Errorf("failed to write chunk length: %w", err)
|
||||
}
|
||||
|
||||
if _, err := w.baseWriter.Write(encrypted); err != nil {
|
||||
return fmt.Errorf("failed to write encrypted chunk: %w", err)
|
||||
}
|
||||
|
||||
w.chunkIndex++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *EncryptionWriter) generateChunkNonce() []byte {
|
||||
chunkNonce := make([]byte, NonceLen)
|
||||
copy(chunkNonce, w.nonce)
|
||||
|
||||
binary.BigEndian.PutUint64(chunkNonce[4:], w.chunkIndex)
|
||||
|
||||
return chunkNonce
|
||||
}
|
||||
@@ -0,0 +1,387 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_EncryptDecryptRoundTrip_ReturnsOriginalData(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte(
|
||||
"This is a test backup data that should be encrypted and then decrypted successfully.",
|
||||
)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted := make([]byte, len(originalData))
|
||||
n, err = io.ReadFull(reader, decrypted)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_EncryptDecryptRoundTrip_LargeData_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := make([]byte, 100*1024)
|
||||
_, err = rand.Read(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
n, err := writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(originalData), n)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_MultipleWrites_CombinesCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
part1 := []byte("First part of data. ")
|
||||
part2 := []byte("Second part of data. ")
|
||||
part3 := []byte("Third part of data.")
|
||||
expectedData := append(append(part1, part2...), part3...)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(part1)
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(part2)
|
||||
require.NoError(t, err)
|
||||
_, err = writer.Write(part3)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedData, decrypted)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_InvalidHeader_ReturnsError(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
invalidHeader := make([]byte, HeaderLen)
|
||||
copy(invalidHeader, []byte("INVALID!"))
|
||||
|
||||
invalidData := bytes.NewBuffer(invalidHeader)
|
||||
|
||||
_, err = NewDecryptionReader(invalidData, masterKey, backupID, salt, nonce)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid magic bytes")
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_TamperedData_ReturnsError(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("This data will be tampered with.")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
encryptedBytes := encrypted.Bytes()
|
||||
if len(encryptedBytes) > HeaderLen+10 {
|
||||
encryptedBytes[HeaderLen+10] ^= 0xFF
|
||||
}
|
||||
|
||||
tamperedBuffer := bytes.NewBuffer(encryptedBytes)
|
||||
|
||||
reader, err := NewDecryptionReader(tamperedBuffer, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
}
|
||||
|
||||
func Test_DeriveBackupKey_SameInputs_ReturnsSameKey(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
key1, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
key2, err := DeriveBackupKey(masterKey, backupID, salt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, key1, key2)
|
||||
}
|
||||
|
||||
func Test_DeriveBackupKey_DifferentInputs_ReturnsDifferentKeys(t *testing.T) {
|
||||
masterKey1 := uuid.New().String() + uuid.New().String()
|
||||
masterKey2 := uuid.New().String() + uuid.New().String()
|
||||
backupID1 := uuid.New()
|
||||
backupID2 := uuid.New()
|
||||
salt1, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
salt2, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
key1, err := DeriveBackupKey(masterKey1, backupID1, salt1)
|
||||
require.NoError(t, err)
|
||||
|
||||
key2, err := DeriveBackupKey(masterKey2, backupID1, salt1)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key2)
|
||||
|
||||
key3, err := DeriveBackupKey(masterKey1, backupID2, salt1)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key3)
|
||||
|
||||
key4, err := DeriveBackupKey(masterKey1, backupID1, salt2)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, key1, key4)
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_PartialChunk_HandledCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
smallData := []byte("Small data less than chunk size")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(smallData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, smallData, decrypted)
|
||||
}
|
||||
|
||||
func Test_GenerateSalt_ReturnsCorrectLength(t *testing.T) {
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, SaltLen, len(salt))
|
||||
}
|
||||
|
||||
func Test_GenerateSalt_GeneratesUniqueSalts(t *testing.T) {
|
||||
salt1, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
salt2, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, salt1, salt2)
|
||||
}
|
||||
|
||||
func Test_GenerateNonce_ReturnsCorrectLength(t *testing.T) {
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, NonceLen, len(nonce))
|
||||
}
|
||||
|
||||
func Test_GenerateNonce_GeneratesUniqueNonces(t *testing.T) {
|
||||
nonce1, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
nonce2, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, nonce1, nonce2)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_WrongMasterKey_ReturnsError(t *testing.T) {
|
||||
masterKey1 := uuid.New().String() + uuid.New().String()
|
||||
masterKey2 := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("Secret data")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey1, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey2, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = io.ReadAll(reader)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "authentication failed")
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_EmptyData_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, len(decrypted))
|
||||
}
|
||||
|
||||
func Test_EncryptionWriter_MultipleChunks_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
dataSize := ChunkSize*3 + 1000
|
||||
originalData := make([]byte, dataSize)
|
||||
_, err = rand.Read(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
decrypted, err := io.ReadAll(reader)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
|
||||
func Test_DecryptionReader_SmallReads_WorksCorrectly(t *testing.T) {
|
||||
masterKey := uuid.New().String() + uuid.New().String()
|
||||
backupID := uuid.New()
|
||||
salt, err := GenerateSalt()
|
||||
require.NoError(t, err)
|
||||
nonce, err := GenerateNonce()
|
||||
require.NoError(t, err)
|
||||
|
||||
originalData := []byte("This is test data that will be read in small chunks.")
|
||||
|
||||
var encrypted bytes.Buffer
|
||||
writer, err := NewEncryptionWriter(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = writer.Write(originalData)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = writer.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
reader, err := NewDecryptionReader(&encrypted, masterKey, backupID, salt, nonce)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decrypted []byte
|
||||
buf := make([]byte, 5)
|
||||
for {
|
||||
n, err := reader.Read(buf)
|
||||
if n > 0 {
|
||||
decrypted = append(decrypted, buf[:n]...)
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, originalData, decrypted)
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package encryption
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
const (
|
||||
MagicBytes = "PGRSUS01"
|
||||
MagicBytesLen = 8
|
||||
SaltLen = 32
|
||||
NonceLen = 12
|
||||
ReservedLen = 12
|
||||
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
|
||||
ChunkSize = 32 * 1024
|
||||
PBKDF2Iterations = 100000
|
||||
)
|
||||
|
||||
func DeriveBackupKey(masterKey string, backupID uuid.UUID, salt []byte) ([]byte, error) {
|
||||
if masterKey == "" {
|
||||
return nil, fmt.Errorf("master key cannot be empty")
|
||||
}
|
||||
if len(salt) != SaltLen {
|
||||
return nil, fmt.Errorf("salt must be %d bytes", SaltLen)
|
||||
}
|
||||
|
||||
keyMaterial := []byte(masterKey + backupID.String())
|
||||
|
||||
derivedKey := pbkdf2.Key(keyMaterial, salt, PBKDF2Iterations, 32, sha256.New)
|
||||
|
||||
return derivedKey, nil
|
||||
}
|
||||
|
||||
func GenerateSalt() ([]byte, error) {
|
||||
salt := make([]byte, SaltLen)
|
||||
if _, err := rand.Read(salt); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
return salt, nil
|
||||
}
|
||||
|
||||
func GenerateNonce() ([]byte, error) {
|
||||
nonce := make([]byte, NonceLen)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
return nonce, nil
|
||||
}
|
||||
@@ -6,4 +6,5 @@ const (
|
||||
BackupStatusInProgress BackupStatus = "IN_PROGRESS"
|
||||
BackupStatusCompleted BackupStatus = "COMPLETED"
|
||||
BackupStatusFailed BackupStatus = "FAILED"
|
||||
BackupStatusCanceled BackupStatus = "CANCELED"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
@@ -19,6 +22,7 @@ type NotificationSender interface {
|
||||
|
||||
type CreateBackupUsecase interface {
|
||||
Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -26,7 +30,7 @@ type CreateBackupUsecase interface {
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error
|
||||
) (*usecases_postgresql.BackupMetadata, error)
|
||||
}
|
||||
|
||||
type BackupRemoveListener interface {
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -11,11 +10,8 @@ import (
|
||||
type Backup struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;type:uuid;primaryKey"`
|
||||
|
||||
Database *databases.Database `json:"database" gorm:"foreignKey:DatabaseID"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
|
||||
Storage *storages.Storage `json:"storage" gorm:"foreignKey:StorageID"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
DatabaseID uuid.UUID `json:"databaseId" gorm:"column:database_id;type:uuid;not null"`
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"column:storage_id;type:uuid;not null"`
|
||||
|
||||
Status BackupStatus `json:"status" gorm:"column:status;not null"`
|
||||
FailMessage *string `json:"failMessage" gorm:"column:fail_message"`
|
||||
@@ -24,5 +20,9 @@ type Backup struct {
|
||||
|
||||
BackupDurationMs int64 `json:"backupDurationMs" gorm:"column:backup_duration_ms;default:0"`
|
||||
|
||||
EncryptionSalt *string `json:"-" gorm:"column:encryption_salt"`
|
||||
EncryptionIV *string `json:"-" gorm:"column:encryption_iv"`
|
||||
Encryption backups_config.BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt" gorm:"column:created_at"`
|
||||
}
|
||||
|
||||
@@ -13,18 +13,20 @@ import (
|
||||
type BackupRepository struct{}
|
||||
|
||||
func (r *BackupRepository) Save(backup *Backup) error {
|
||||
if backup.DatabaseID == uuid.Nil || backup.StorageID == uuid.Nil {
|
||||
return errors.New("database ID and storage ID are required")
|
||||
}
|
||||
|
||||
db := storage.GetDb()
|
||||
|
||||
isNew := backup.ID == uuid.Nil
|
||||
if isNew {
|
||||
backup.ID = uuid.New()
|
||||
return db.Create(backup).
|
||||
Omit("Database", "Storage").
|
||||
Error
|
||||
}
|
||||
|
||||
return db.Save(backup).
|
||||
Omit("Database", "Storage").
|
||||
Error
|
||||
}
|
||||
|
||||
@@ -33,8 +35,6 @@ func (r *BackupRepository) FindByDatabaseID(databaseID uuid.UUID) ([]*Backup, er
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -56,8 +56,6 @@ func (r *BackupRepository) FindByDatabaseIDWithLimit(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
@@ -73,8 +71,6 @@ func (r *BackupRepository) FindByStorageID(storageID uuid.UUID) ([]*Backup, erro
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("storage_id = ?", storageID).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -89,8 +85,6 @@ func (r *BackupRepository) FindLastByDatabaseID(databaseID uuid.UUID) (*Backup,
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
First(&backup).Error; err != nil {
|
||||
@@ -109,8 +103,6 @@ func (r *BackupRepository) FindByID(id uuid.UUID) (*Backup, error) {
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("id = ?", id).
|
||||
First(&backup).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -124,8 +116,6 @@ func (r *BackupRepository) FindByStatus(status BackupStatus) ([]*Backup, error)
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -143,8 +133,6 @@ func (r *BackupRepository) FindByStorageIdAndStatus(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("storage_id = ? AND status = ?", storageID, status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -162,8 +150,6 @@ func (r *BackupRepository) FindByDatabaseIdAndStatus(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ? AND status = ?", databaseID, status).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -185,8 +171,6 @@ func (r *BackupRepository) FindBackupsBeforeDate(
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Database").
|
||||
Preload("Storage").
|
||||
Where("database_id = ? AND created_at < ?", databaseID, date).
|
||||
Order("created_at DESC").
|
||||
Find(&backups).Error; err != nil {
|
||||
@@ -195,3 +179,36 @@ func (r *BackupRepository) FindBackupsBeforeDate(
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) FindByDatabaseIDWithPagination(
|
||||
databaseID uuid.UUID,
|
||||
limit, offset int,
|
||||
) ([]*Backup, error) {
|
||||
var backups []*Backup
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Where("database_id = ?", databaseID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&backups).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
}
|
||||
|
||||
func (r *BackupRepository) CountByDatabaseID(databaseID uuid.UUID) (int64, error) {
|
||||
var count int64
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Model(&Backup{}).
|
||||
Where("database_id = ?", databaseID).
|
||||
Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
@@ -1,17 +1,26 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
"slices"
|
||||
"time"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -23,12 +32,18 @@ type BackupService struct {
|
||||
notifierService *notifiers.NotifierService
|
||||
notificationSender NotificationSender
|
||||
backupConfigService *backups_config.BackupConfigService
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor util_encryption.FieldEncryptor
|
||||
|
||||
createBackupUseCase CreateBackupUsecase
|
||||
|
||||
logger *slog.Logger
|
||||
|
||||
backupRemoveListeners []BackupRemoveListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
backupContextManager *BackupContextManager
|
||||
}
|
||||
|
||||
func (s *BackupService) AddBackupRemoveListener(listener BackupRemoveListener) {
|
||||
@@ -62,34 +77,74 @@ func (s *BackupService) MakeBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot create backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canAccess {
|
||||
return errors.New("insufficient permissions to create backup for this database")
|
||||
}
|
||||
|
||||
go s.MakeBackup(databaseID, true)
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Backup manually initiated for database: %s", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackups(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) ([]*Backup, error) {
|
||||
limit, offset int,
|
||||
) (*GetBackupsResponse, error) {
|
||||
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot get backups for database without workspace")
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseID(databaseID)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access backups for this database")
|
||||
}
|
||||
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
backups, err := s.backupRepository.FindByDatabaseIDWithPagination(databaseID, limit, offset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return backups, nil
|
||||
total, err := s.backupRepository.CountByDatabaseID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &GetBackupsResponse{
|
||||
Backups: backups,
|
||||
Total: total,
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(
|
||||
@@ -101,14 +156,37 @@ func (s *BackupService) DeleteBackup(
|
||||
return err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this backup")
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot delete backup for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to delete backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status == BackupStatusInProgress {
|
||||
return errors.New("backup is in progress")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup deleted for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.deleteBackup(backup)
|
||||
}
|
||||
|
||||
@@ -154,10 +232,7 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
|
||||
backup := &Backup{
|
||||
DatabaseID: databaseID,
|
||||
Database: database,
|
||||
|
||||
StorageID: storage.ID,
|
||||
Storage: storage,
|
||||
StorageID: storage.ID,
|
||||
|
||||
Status: BackupStatusInProgress,
|
||||
|
||||
@@ -184,7 +259,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
}
|
||||
}
|
||||
|
||||
err = s.createBackupUseCase.Execute(
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s.backupContextManager.RegisterBackup(backup.ID, cancel)
|
||||
defer s.backupContextManager.UnregisterBackup(backup.ID)
|
||||
|
||||
backupMetadata, err := s.createBackupUseCase.Execute(
|
||||
ctx,
|
||||
backup.ID,
|
||||
backupConfig,
|
||||
database,
|
||||
@@ -193,6 +273,34 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save cancelled backup", "error", err)
|
||||
}
|
||||
|
||||
// Delete partial backup from storage
|
||||
storage, storageErr := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if storageErr == nil {
|
||||
if deleteErr := storage.DeleteFile(s.fieldEncryptor, backup.ID); deleteErr != nil {
|
||||
s.logger.Error(
|
||||
"Failed to delete partial backup file",
|
||||
"backupId",
|
||||
backup.ID,
|
||||
"error",
|
||||
deleteErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
backup.FailMessage = &errMsg
|
||||
backup.Status = BackupStatusFailed
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
@@ -225,6 +333,13 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
backup.Status = BackupStatusCompleted
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
|
||||
// Update backup with encryption metadata if provided
|
||||
if backupMetadata != nil {
|
||||
backup.EncryptionSalt = backupMetadata.EncryptionSalt
|
||||
backup.EncryptionIV = backupMetadata.EncryptionIV
|
||||
backup.Encryption = backupMetadata.Encryption
|
||||
}
|
||||
|
||||
if err := s.backupRepository.Save(backup); err != nil {
|
||||
s.logger.Error("Failed to save backup", "error", err)
|
||||
return
|
||||
@@ -265,6 +380,11 @@ func (s *BackupService) SendBackupNotification(
|
||||
return
|
||||
}
|
||||
|
||||
workspace, err := s.workspaceService.GetWorkspaceByID(*database.WorkspaceID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, notifier := range database.Notifiers {
|
||||
if !slices.Contains(
|
||||
backupConfig.SendNotificationsOn,
|
||||
@@ -276,9 +396,17 @@ func (s *BackupService) SendBackupNotification(
|
||||
title := ""
|
||||
switch notificationType {
|
||||
case backups_config.NotificationBackupFailed:
|
||||
title = fmt.Sprintf("❌ Backup failed for database \"%s\"", database.Name)
|
||||
title = fmt.Sprintf(
|
||||
"❌ Backup failed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
case backups_config.NotificationBackupSuccess:
|
||||
title = fmt.Sprintf("✅ Backup completed for database \"%s\"", database.Name)
|
||||
title = fmt.Sprintf(
|
||||
"✅ Backup completed for database \"%s\" (workspace \"%s\")",
|
||||
database.Name,
|
||||
workspace.Name,
|
||||
)
|
||||
}
|
||||
|
||||
message := ""
|
||||
@@ -319,6 +447,53 @@ func (s *BackupService) GetBackup(backupID uuid.UUID) (*Backup, error) {
|
||||
return s.backupRepository.FindByID(backupID)
|
||||
}
|
||||
|
||||
func (s *BackupService) CancelBackup(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) error {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot cancel backup for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to cancel backup for this database")
|
||||
}
|
||||
|
||||
if backup.Status != BackupStatusInProgress {
|
||||
return errors.New("backup is not in progress")
|
||||
}
|
||||
|
||||
if err := s.backupContextManager.CancelBackup(backupID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup cancelled for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
@@ -328,16 +503,37 @@ func (s *BackupService) GetBackupFile(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if backup.Database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this backup")
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return storage.GetFile(backup.ID)
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot download backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*database.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to download backup for this database")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Backup file downloaded for database: %s (ID: %s)",
|
||||
database.Name,
|
||||
backupID.String(),
|
||||
),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.getBackupReader(backupID)
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
@@ -352,9 +548,12 @@ func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = storage.DeleteFile(backup.ID)
|
||||
err = storage.DeleteFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
// we do not return error here, because sometimes clean up performed
|
||||
// before unavailable storage removal or change - therefore we should
|
||||
// proceed even in case of error
|
||||
s.logger.Error("Failed to delete backup file", "error", err)
|
||||
}
|
||||
|
||||
return s.backupRepository.DeleteByID(backup.ID)
|
||||
@@ -389,3 +588,91 @@ func (s *BackupService) deleteDbBackups(databaseID uuid.UUID) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetBackupReader returns a reader for the backup file
|
||||
// If encrypted, wraps with DecryptionReader
|
||||
func (s *BackupService) getBackupReader(backupID uuid.UUID) (io.ReadCloser, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to find backup: %w", err)
|
||||
}
|
||||
|
||||
storage, err := s.storageService.GetStorageByID(backup.StorageID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get storage: %w", err)
|
||||
}
|
||||
|
||||
fileReader, err := storage.GetFile(s.fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get backup file: %w", err)
|
||||
}
|
||||
|
||||
// If not encrypted, return raw reader
|
||||
if backup.Encryption == backups_config.BackupEncryptionNone {
|
||||
s.logger.Info("Returning non-encrypted backup", "backupId", backupID)
|
||||
return fileReader, nil
|
||||
}
|
||||
|
||||
// Decrypt on-the-fly for encrypted backups
|
||||
if backup.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", err)
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported encryption type: %s", backup.Encryption)
|
||||
}
|
||||
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
if err := fileReader.Close(); err != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", err)
|
||||
}
|
||||
return nil, fmt.Errorf("backup marked as encrypted but missing encryption metadata")
|
||||
}
|
||||
|
||||
// Get master key
|
||||
masterKey, err := s.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
// Decode salt and IV
|
||||
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode salt: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decode IV: %w", err)
|
||||
}
|
||||
|
||||
// Wrap with decrypting reader
|
||||
decryptionReader, err := encryption.NewDecryptionReader(
|
||||
fileReader,
|
||||
masterKey,
|
||||
backup.ID,
|
||||
salt,
|
||||
iv,
|
||||
)
|
||||
if err != nil {
|
||||
if closeErr := fileReader.Close(); closeErr != nil {
|
||||
s.logger.Error("Failed to close file reader", "error", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create decrypting reader: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("Returning encrypted backup with decryption", "backupId", backupID)
|
||||
|
||||
return &decryptionReaderCloser{
|
||||
decryptionReader,
|
||||
fileReader,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,15 +1,24 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -17,15 +26,27 @@ import (
|
||||
)
|
||||
|
||||
func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
router := CreateTestRouter()
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", user, router)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
backups_config.EnableBackupsForTestDatabase(database.ID, storage)
|
||||
|
||||
defer storages.RemoveTestStorage(storage.ID)
|
||||
defer notifiers.RemoveTestNotifier(notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
defer func() {
|
||||
// cleanup backups first
|
||||
backups, _ := backupRepository.FindByDatabaseID(database.ID)
|
||||
for _, backup := range backups {
|
||||
backupRepository.DeleteByID(backup.ID)
|
||||
}
|
||||
|
||||
databases.RemoveTestDatabase(database)
|
||||
time.Sleep(50 * time.Millisecond) // Wait for cascading deletes
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}()
|
||||
|
||||
t.Run("BackupFailed_FailNotificationSent", func(t *testing.T) {
|
||||
mockNotificationSender := &MockNotificationSender{}
|
||||
@@ -36,9 +57,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateFailedBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
// Set up expectations
|
||||
@@ -79,9 +105,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
backupService.MakeBackup(database.ID, true)
|
||||
@@ -99,9 +130,14 @@ func Test_BackupExecuted_NotificationSent(t *testing.T) {
|
||||
notifiers.GetNotifierService(),
|
||||
mockNotificationSender,
|
||||
backups_config.GetBackupConfigService(),
|
||||
encryption_secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
&CreateSuccessBackupUsecase{},
|
||||
logger.GetLogger(),
|
||||
[]BackupRemoveListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
NewBackupContextManager(),
|
||||
}
|
||||
|
||||
// capture arguments
|
||||
@@ -137,6 +173,7 @@ type CreateFailedBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateFailedBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -144,15 +181,16 @@ func (uc *CreateFailedBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
return errors.New("backup failed")
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct {
|
||||
}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -160,7 +198,11 @@ func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
return nil
|
||||
return &usecases_postgresql.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
}, nil
|
||||
}
|
||||
|
||||
20
backend/internal/features/backups/backups/testing.go
Normal file
20
backend/internal/features/backups/backups/testing.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CreateTestRouter() *gin.Engine {
|
||||
return workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
backups_config.GetBackupConfigController(),
|
||||
GetBackupController(),
|
||||
)
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
@@ -14,8 +15,9 @@ type CreateBackupUsecase struct {
|
||||
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database and returns the backup size in MB
|
||||
// Execute creates a backup of the database and returns the backup metadata
|
||||
func (uc *CreateBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
@@ -23,9 +25,10 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
return uc.CreatePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
database,
|
||||
@@ -34,5 +37,5 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
)
|
||||
}
|
||||
|
||||
return errors.New("database type not supported")
|
||||
return nil, errors.New("database type not supported")
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package usecases_postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,21 +15,39 @@ import (
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
backupTimeout = 23 * time.Hour
|
||||
shutdownCheckInterval = 1 * time.Second
|
||||
copyBufferSize = 32 * 1024
|
||||
progressReportIntervalMB = 1.0
|
||||
pgConnectTimeout = 30
|
||||
compressionLevel = 5
|
||||
exitCodeAccessViolation = -1073741819
|
||||
exitCodeGenericError = 1
|
||||
exitCodeConnectionError = 2
|
||||
)
|
||||
|
||||
type CreatePostgresqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
logger *slog.Logger
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database
|
||||
func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
@@ -36,7 +55,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) error {
|
||||
) (*BackupMetadata, error) {
|
||||
uc.logger.Info(
|
||||
"Creating PostgreSQL backup via pg_dump custom format",
|
||||
"databaseId",
|
||||
@@ -46,41 +65,28 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
pg := db.Postgresql
|
||||
|
||||
if pg == nil {
|
||||
return fmt.Errorf("postgresql database configuration is required for pg_dump backups")
|
||||
return nil, fmt.Errorf("postgresql database configuration is required for pg_dump backups")
|
||||
}
|
||||
|
||||
if pg.Database == nil || *pg.Database == "" {
|
||||
return fmt.Errorf("database name is required for pg_dump backups")
|
||||
return nil, fmt.Errorf("database name is required for pg_dump backups")
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"-Fc", // custom format with built-in compression
|
||||
"--no-password", // Use environment variable for password, prevent prompts
|
||||
"-h", pg.Host,
|
||||
"-p", strconv.Itoa(pg.Port),
|
||||
"-U", pg.Username,
|
||||
"-d", *pg.Database,
|
||||
"--verbose", // Add verbose output to help with debugging
|
||||
}
|
||||
args := uc.buildPgDumpArgs(pg)
|
||||
|
||||
// Use zstd compression level 5 for PostgreSQL 15+ (better compression and speed)
|
||||
// Fall back to gzip compression level 5 for older versions
|
||||
if pg.Version == tools.PostgresqlVersion13 || pg.Version == tools.PostgresqlVersion14 ||
|
||||
pg.Version == tools.PostgresqlVersion15 {
|
||||
args = append(args, "-Z", "5")
|
||||
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", pg.Version)
|
||||
} else {
|
||||
args = append(args, "--compress=zstd:5")
|
||||
uc.logger.Info("Using zstd compression level 5", "version", pg.Version)
|
||||
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, pg.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
|
||||
}
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
tools.GetPostgresqlExecutable(
|
||||
@@ -90,7 +96,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
config.GetEnv().PostgresesInstallDir,
|
||||
),
|
||||
args,
|
||||
pg.Password,
|
||||
decryptedPassword,
|
||||
storage,
|
||||
db,
|
||||
backupProgressListener,
|
||||
@@ -99,6 +105,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
|
||||
// streamToStorage streams pg_dump output directly to storage
|
||||
func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
pgBin string,
|
||||
@@ -107,36 +114,15 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
storage *storages.Storage,
|
||||
db *databases.Database,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) error {
|
||||
) (*BackupMetadata, error) {
|
||||
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
|
||||
|
||||
// if backup not fit into 23 hours, Postgresus
|
||||
// seems not to work for such database size
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 23*time.Hour)
|
||||
ctx, cancel := uc.createBackupContext(parentCtx)
|
||||
defer cancel()
|
||||
|
||||
// Monitor for shutdown and cancel context if needed
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Create temporary .pgpass file as a more reliable alternative to PGPASSWORD
|
||||
pgpassFile, err := uc.createTempPgpassFile(db.Postgresql, password)
|
||||
pgpassFile, err := uc.setupPgpassFile(db.Postgresql, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create temporary .pgpass file: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if pgpassFile != "" {
|
||||
@@ -144,87 +130,21 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
}
|
||||
}()
|
||||
|
||||
// Verify .pgpass file was created successfully
|
||||
if pgpassFile == "" {
|
||||
return fmt.Errorf("temporary .pgpass file was not created")
|
||||
}
|
||||
|
||||
// Verify .pgpass file was created correctly
|
||||
if info, err := os.Stat(pgpassFile); err == nil {
|
||||
uc.logger.Info("Temporary .pgpass file created successfully",
|
||||
"pgpassFile", pgpassFile,
|
||||
"size", info.Size(),
|
||||
"mode", info.Mode(),
|
||||
)
|
||||
} else {
|
||||
return fmt.Errorf("failed to verify .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, pgBin, args...)
|
||||
uc.logger.Info("Executing PostgreSQL backup command", "command", cmd.String())
|
||||
|
||||
// Start with system environment variables to preserve Windows PATH, SystemRoot, etc.
|
||||
cmd.Env = os.Environ()
|
||||
|
||||
// Use the .pgpass file for authentication
|
||||
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
|
||||
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
|
||||
|
||||
// Debug password setup (without exposing the actual password)
|
||||
uc.logger.Info("Setting up PostgreSQL environment",
|
||||
"passwordLength", len(password),
|
||||
"passwordEmpty", password == "",
|
||||
"pgBin", pgBin,
|
||||
"usingPgpassFile", true,
|
||||
"parallelJobs", backupConfig.CpuCount,
|
||||
)
|
||||
|
||||
// Add PostgreSQL-specific environment variables
|
||||
cmd.Env = append(cmd.Env, "PGCLIENTENCODING=UTF8")
|
||||
cmd.Env = append(cmd.Env, "PGCONNECT_TIMEOUT=30")
|
||||
|
||||
// Add encoding-related environment variables to handle character encoding issues
|
||||
cmd.Env = append(cmd.Env, "LC_ALL=C.UTF-8")
|
||||
cmd.Env = append(cmd.Env, "LANG=C.UTF-8")
|
||||
|
||||
// Add PostgreSQL-specific encoding settings
|
||||
cmd.Env = append(cmd.Env, "PGOPTIONS=--client-encoding=UTF8")
|
||||
|
||||
shouldRequireSSL := db.Postgresql.IsHttps
|
||||
|
||||
// Require SSL when explicitly configured
|
||||
if shouldRequireSSL {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
|
||||
uc.logger.Info("Using required SSL mode", "configuredHttps", db.Postgresql.IsHttps)
|
||||
} else {
|
||||
// SSL not explicitly required, but prefer it if available
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
|
||||
uc.logger.Info("Using preferred SSL mode", "configuredHttps", db.Postgresql.IsHttps)
|
||||
}
|
||||
|
||||
// Set other SSL parameters to avoid certificate issues
|
||||
cmd.Env = append(cmd.Env, "PGSSLCERT=") // No client certificate
|
||||
cmd.Env = append(cmd.Env, "PGSSLKEY=") // No client key
|
||||
cmd.Env = append(cmd.Env, "PGSSLROOTCERT=") // No root certificate verification
|
||||
cmd.Env = append(cmd.Env, "PGSSLCRL=") // No certificate revocation list
|
||||
|
||||
// Verify executable exists and is accessible
|
||||
if _, err := exec.LookPath(pgBin); err != nil {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL executable not found or not accessible: %s - %w",
|
||||
pgBin,
|
||||
err,
|
||||
)
|
||||
if err := uc.setupPgEnvironment(cmd, pgpassFile, db.Postgresql.IsHttps, password, backupConfig.CpuCount, pgBin); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pgStdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
return nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
pgStderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
return nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
// Capture stderr in a separate goroutine to ensure we don't miss any error output
|
||||
@@ -234,23 +154,31 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
stderrCh <- stderrOutput
|
||||
}()
|
||||
|
||||
// A pipe connecting pg_dump output → storage
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
// Create a counting writer to track bytes
|
||||
countingWriter := &CountingWriter{writer: storageWriter}
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
countingWriter := &CountingWriter{writer: finalWriter}
|
||||
|
||||
// The backup ID becomes the object key / filename in storage
|
||||
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErrCh <- storage.SaveFile(uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
// Start pg_dump
|
||||
if err = cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
|
||||
return nil, fmt.Errorf("start %s: %w", filepath.Base(pgBin), err)
|
||||
}
|
||||
|
||||
// Copy pg output directly to storage with shutdown checks
|
||||
@@ -272,23 +200,17 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
bytesWritten := <-bytesWrittenCh
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Check for shutdown before finalizing
|
||||
if config.IsShouldShutdown() {
|
||||
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
|
||||
if err := pipeWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close counting writer", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
<-saveErrCh // Wait for storage to finish
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
// Check for shutdown or cancellation before finalizing
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
|
||||
return nil, uc.checkCancellationReason()
|
||||
default:
|
||||
}
|
||||
|
||||
// Close the pipe writer to signal end of data
|
||||
if pipeWriter, ok := countingWriter.writer.(*io.PipeWriter); ok {
|
||||
if err := pipeWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close counting writer", "error", err)
|
||||
}
|
||||
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
|
||||
<-saveErrCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait until storage ends reading
|
||||
@@ -303,134 +225,34 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
|
||||
switch {
|
||||
case waitErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enhanced error handling for PostgreSQL connection and SSL issues
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
// Check for specific PostgreSQL error patterns
|
||||
if exitErr, ok := waitErr.(*exec.ExitError); ok {
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
// Enhanced debugging for exit status 1 with empty stderr
|
||||
if exitCode == 1 && strings.TrimSpace(stderrStr) == "" {
|
||||
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"env_vars", []string{
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT=30",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
},
|
||||
)
|
||||
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s failed with exit status 1 but provided no error details. "+
|
||||
"This often indicates: "+
|
||||
"1) Connection timeout or refused connection, "+
|
||||
"2) Authentication failure with incorrect credentials, "+
|
||||
"3) Database does not exist, "+
|
||||
"4) Network connectivity issues, "+
|
||||
"5) PostgreSQL server not running. "+
|
||||
"Command executed: %s %s",
|
||||
filepath.Base(pgBin),
|
||||
pgBin,
|
||||
strings.Join(args, " "),
|
||||
)
|
||||
} else if exitCode == -1073741819 { // 0xC0000005 in decimal
|
||||
uc.logger.Error("PostgreSQL tool crashed with access violation",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"exitCode", fmt.Sprintf("0x%X", uint32(exitCode)),
|
||||
)
|
||||
|
||||
errorMsg = fmt.Sprintf(
|
||||
"%s crashed with access violation (0xC0000005). This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
stderrStr,
|
||||
)
|
||||
} else if exitCode == 1 || exitCode == 2 {
|
||||
// Check for common connection and authentication issues
|
||||
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection rejected by server configuration (pg_hba.conf). The server may not allow connections from your IP address or may require different authentication settings. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "no password supplied") || containsIgnoreCase(stderrStr, "fe_sendauth") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed - no password supplied. "+
|
||||
"PGPASSWORD environment variable may not be working correctly on this system. "+
|
||||
"Password length: %d, Password empty: %v. "+
|
||||
"Consider using a .pgpass file as an alternative. stderr: %s",
|
||||
len(password),
|
||||
password == "",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL SSL connection failed. The server may require SSL encryption or have SSL configuration issues. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection refused. Check if the server is running and accessible from your network. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "authentication") || containsIgnoreCase(stderrStr, "password") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL authentication failed. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
} else if containsIgnoreCase(stderrStr, "timeout") {
|
||||
errorMsg = fmt.Sprintf(
|
||||
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
return nil, uc.buildPgDumpErrorMessage(waitErr, stderrOutput, pgBin, args, password)
|
||||
case copyErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fmt.Errorf("copy to storage: %w", copyErr)
|
||||
return nil, fmt.Errorf("copy to storage: %w", copyErr)
|
||||
case saveErr != nil:
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
if err := uc.checkCancellation(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return fmt.Errorf("save to storage: %w", saveErr)
|
||||
return nil, fmt.Errorf("save to storage: %w", saveErr)
|
||||
}
|
||||
|
||||
return nil
|
||||
return &backupMetadata, nil
|
||||
}
|
||||
|
||||
// copyWithShutdownCheck copies data from src to dst while checking for shutdown
|
||||
func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (int64, error) {
|
||||
buf := make([]byte, 32*1024) // 32KB buffer
|
||||
buf := make([]byte, copyBufferSize)
|
||||
var totalBytesWritten int64
|
||||
|
||||
// Progress reporting interval - report every 1MB of data
|
||||
var lastReportedMB float64
|
||||
const reportIntervalMB = 1.0
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -463,12 +285,9 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
|
||||
// Report progress based on total size
|
||||
if backupProgressListener != nil {
|
||||
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
|
||||
|
||||
// Only report if we've written at least 1MB more data than last report
|
||||
if currentSizeMB >= lastReportedMB+reportIntervalMB {
|
||||
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
|
||||
backupProgressListener(currentSizeMB)
|
||||
lastReportedMB = currentSizeMB
|
||||
}
|
||||
@@ -479,7 +298,6 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -487,12 +305,412 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
// containsIgnoreCase checks if a string contains a substring, ignoring case
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlDatabase) []string {
|
||||
args := []string{
|
||||
"-Fc",
|
||||
"--no-password",
|
||||
"-h", pg.Host,
|
||||
"-p", strconv.Itoa(pg.Port),
|
||||
"-U", pg.Username,
|
||||
"-d", *pg.Database,
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
compressionArgs := uc.getCompressionArgs(pg.Version)
|
||||
return append(args, compressionArgs...)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) getCompressionArgs(
|
||||
version tools.PostgresqlVersion,
|
||||
) []string {
|
||||
if uc.isOlderPostgresVersion(version) {
|
||||
uc.logger.Info("Using gzip compression level 5 (zstd not available)", "version", version)
|
||||
return []string{"-Z", strconv.Itoa(compressionLevel)}
|
||||
}
|
||||
|
||||
uc.logger.Info("Using zstd compression level 5", "version", version)
|
||||
return []string{fmt.Sprintf("--compress=zstd:%d", compressionLevel)}
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) isOlderPostgresVersion(
|
||||
version tools.PostgresqlVersion,
|
||||
) bool {
|
||||
return version == tools.PostgresqlVersion12 ||
|
||||
version == tools.PostgresqlVersion13 ||
|
||||
version == tools.PostgresqlVersion14 ||
|
||||
version == tools.PostgresqlVersion15
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
|
||||
parentCtx context.Context,
|
||||
) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(shutdownCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupPgpassFile(
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
pgpassFile, err := uc.createTempPgpassFile(pgConfig, password)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
if pgpassFile == "" {
|
||||
return "", fmt.Errorf("temporary .pgpass file was not created")
|
||||
}
|
||||
|
||||
if info, err := os.Stat(pgpassFile); err == nil {
|
||||
uc.logger.Info("Temporary .pgpass file created successfully",
|
||||
"pgpassFile", pgpassFile,
|
||||
"size", info.Size(),
|
||||
"mode", info.Mode(),
|
||||
)
|
||||
} else {
|
||||
return "", fmt.Errorf("failed to verify .pgpass file: %w", err)
|
||||
}
|
||||
|
||||
return pgpassFile, nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupPgEnvironment(
|
||||
cmd *exec.Cmd,
|
||||
pgpassFile string,
|
||||
shouldRequireSSL bool,
|
||||
password string,
|
||||
cpuCount int,
|
||||
pgBin string,
|
||||
) error {
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env, "PGPASSFILE="+pgpassFile)
|
||||
|
||||
uc.logger.Info("Using temporary .pgpass file for authentication", "pgpassFile", pgpassFile)
|
||||
uc.logger.Info("Setting up PostgreSQL environment",
|
||||
"passwordLength", len(password),
|
||||
"passwordEmpty", password == "",
|
||||
"pgBin", pgBin,
|
||||
"usingPgpassFile", true,
|
||||
"parallelJobs", cpuCount,
|
||||
)
|
||||
|
||||
cmd.Env = append(cmd.Env,
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT="+strconv.Itoa(pgConnectTimeout),
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
)
|
||||
|
||||
if shouldRequireSSL {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=require")
|
||||
uc.logger.Info("Using required SSL mode", "configuredHttps", shouldRequireSSL)
|
||||
} else {
|
||||
cmd.Env = append(cmd.Env, "PGSSLMODE=prefer")
|
||||
uc.logger.Info("Using preferred SSL mode", "configuredHttps", shouldRequireSSL)
|
||||
}
|
||||
|
||||
cmd.Env = append(cmd.Env,
|
||||
"PGSSLCERT=",
|
||||
"PGSSLKEY=",
|
||||
"PGSSLROOTCERT=",
|
||||
"PGSSLCRL=",
|
||||
)
|
||||
|
||||
if _, err := exec.LookPath(pgBin); err != nil {
|
||||
return fmt.Errorf("PostgreSQL executable not found or not accessible: %s - %w", pgBin, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
|
||||
metadata := BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
metadata.Encryption = backups_config.BackupEncryptionNone
|
||||
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
|
||||
return storageWriter, nil, metadata, nil
|
||||
}
|
||||
|
||||
salt, err := backup_encryption.GenerateSalt()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := backup_encryption.GenerateNonce()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
||||
storageWriter,
|
||||
masterKey,
|
||||
backupID,
|
||||
salt,
|
||||
nonce,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
||||
}
|
||||
|
||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
||||
metadata.EncryptionSalt = &saltBase64
|
||||
metadata.EncryptionIV = &nonceBase64
|
||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||
|
||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||
return encWriter, encWriter, metadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) cleanupOnCancellation(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
saveErrCh chan error,
|
||||
) {
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
if closeErr := encryptionWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close encrypting writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
|
||||
}
|
||||
|
||||
<-saveErrCh
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) closeWriters(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
) error {
|
||||
encryptionCloseErrCh := make(chan error, 1)
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
closeErr := encryptionWriter.Close()
|
||||
if closeErr != nil {
|
||||
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
|
||||
}
|
||||
encryptionCloseErrCh <- closeErr
|
||||
}()
|
||||
} else {
|
||||
encryptionCloseErrCh <- nil
|
||||
}
|
||||
|
||||
encryptionCloseErr := <-encryptionCloseErrCh
|
||||
if encryptionCloseErr != nil {
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
|
||||
}
|
||||
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) checkCancellation(ctx context.Context) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) checkCancellationReason() error {
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) buildPgDumpErrorMessage(
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
pgBin string,
|
||||
args []string,
|
||||
password string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf("%s failed: %v – stderr: %s", filepath.Base(pgBin), waitErr, stderrStr)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
if exitCode == exitCodeGenericError && strings.TrimSpace(stderrStr) == "" {
|
||||
return uc.handleExitCode1NoStderr(pgBin, args)
|
||||
}
|
||||
|
||||
if exitCode == exitCodeAccessViolation {
|
||||
return uc.handleAccessViolation(pgBin, stderrStr)
|
||||
}
|
||||
|
||||
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
|
||||
return uc.handleConnectionErrors(stderrStr, password)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleExitCode1NoStderr(
|
||||
pgBin string,
|
||||
args []string,
|
||||
) error {
|
||||
uc.logger.Error("pg_dump failed with exit status 1 but no stderr output",
|
||||
"pgBin", pgBin,
|
||||
"args", args,
|
||||
"env_vars", []string{
|
||||
"PGCLIENTENCODING=UTF8",
|
||||
"PGCONNECT_TIMEOUT=" + strconv.Itoa(pgConnectTimeout),
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
"PGOPTIONS=--client-encoding=UTF8",
|
||||
},
|
||||
)
|
||||
|
||||
return fmt.Errorf(
|
||||
"%s failed with exit status 1 but provided no error details. "+
|
||||
"This often indicates: "+
|
||||
"1) Connection timeout or refused connection, "+
|
||||
"2) Authentication failure with incorrect credentials, "+
|
||||
"3) Database does not exist, "+
|
||||
"4) Network connectivity issues, "+
|
||||
"5) PostgreSQL server not running. "+
|
||||
"Command executed: %s %s",
|
||||
filepath.Base(pgBin),
|
||||
pgBin,
|
||||
strings.Join(args, " "),
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleAccessViolation(
|
||||
pgBin string,
|
||||
stderrStr string,
|
||||
) error {
|
||||
uc.logger.Error("PostgreSQL tool crashed with access violation",
|
||||
"pgBin", pgBin,
|
||||
"exitCode", "0xC0000005",
|
||||
)
|
||||
|
||||
return fmt.Errorf(
|
||||
"%s crashed with access violation (0xC0000005). "+
|
||||
"This may indicate incompatible PostgreSQL version, corrupted installation, or connection issues. "+
|
||||
"stderr: %s",
|
||||
filepath.Base(pgBin),
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) handleConnectionErrors(
|
||||
stderrStr string,
|
||||
password string,
|
||||
) error {
|
||||
if containsIgnoreCase(stderrStr, "pg_hba.conf") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection rejected by server configuration (pg_hba.conf). "+
|
||||
"The server may not allow connections from your IP address or may require different authentication settings. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "no password supplied") ||
|
||||
containsIgnoreCase(stderrStr, "fe_sendauth") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL authentication failed - no password supplied. "+
|
||||
"PGPASSWORD environment variable may not be working correctly on this system. "+
|
||||
"Password length: %d, Password empty: %v. "+
|
||||
"Consider using a .pgpass file as an alternative. "+
|
||||
"stderr: %s",
|
||||
len(password),
|
||||
password == "",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") && containsIgnoreCase(stderrStr, "connection") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL SSL connection failed. "+
|
||||
"The server may require SSL encryption or have SSL configuration issues. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "connection") && containsIgnoreCase(stderrStr, "refused") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection refused. "+
|
||||
"Check if the server is running and accessible from your network. "+
|
||||
"stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "authentication") ||
|
||||
containsIgnoreCase(stderrStr, "password") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL authentication failed. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"PostgreSQL connection timeout. The server may be unreachable or overloaded. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("PostgreSQL connection or authentication error. stderr: %s", stderrStr)
|
||||
}
|
||||
|
||||
// createTempPgpassFile creates a temporary .pgpass file with the given password
|
||||
func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
password string,
|
||||
@@ -508,7 +726,6 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
password,
|
||||
)
|
||||
|
||||
// it always create unique directory like /tmp/pgpass-1234567890
|
||||
tempDir, err := os.MkdirTemp("", "pgpass")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
@@ -522,3 +739,7 @@ func (uc *CreatePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
|
||||
return pgpassFile, nil
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var createPostgresqlBackupUsecase = &CreatePostgresqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
func GetCreatePostgresqlBackupUsecase() *CreatePostgresqlBackupUsecase {
|
||||
|
||||
@@ -0,0 +1,15 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import backups_config "postgresus-backend/internal/features/backups/config"
|
||||
|
||||
type EncryptionMetadata struct {
|
||||
Salt string
|
||||
IV string
|
||||
Encryption backups_config.BackupEncryption
|
||||
}
|
||||
|
||||
type BackupMetadata struct {
|
||||
EncryptionSalt *string
|
||||
EncryptionIV *string
|
||||
Encryption backups_config.BackupEncryption
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package backups_config
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type BackupConfigController struct {
|
||||
backupConfigService *BackupConfigService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -21,35 +20,29 @@ func (c *BackupConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
|
||||
// SaveBackupConfig
|
||||
// @Summary Save backup configuration
|
||||
// @Description Save or update backup configuration for a database
|
||||
// @Description Save or update backup configuration for a database. Encryption can be set to NONE (no encryption) or ENCRYPTED (AES-256-GCM encryption).
|
||||
// @Tags backup-configs
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body BackupConfig true "Backup configuration data"
|
||||
// @Success 200 {object} BackupConfig
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Param request body BackupConfig true "Backup configuration data (encryption field: NONE or ENCRYPTED)"
|
||||
// @Success 200 {object} BackupConfig "Returns the saved backup configuration including encryption settings"
|
||||
// @Failure 400 {object} map[string]string "Invalid encryption value or other validation errors"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 500 {object} map[string]string "Internal server error"
|
||||
// @Router /backup-configs/save [post]
|
||||
func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var requestDTO BackupConfig
|
||||
if err := ctx.ShouldBindJSON(&requestDTO); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
// make sure we rely on full .Storage object
|
||||
requestDTO.StorageID = nil
|
||||
|
||||
@@ -64,40 +57,28 @@ func (c *BackupConfigController) SaveBackupConfig(ctx *gin.Context) {
|
||||
|
||||
// GetBackupConfigByDbID
|
||||
// @Summary Get backup configuration by database ID
|
||||
// @Description Get backup configuration for a specific database
|
||||
// @Description Get backup configuration for a specific database including encryption settings (NONE or ENCRYPTED)
|
||||
// @Tags backup-configs
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 200 {object} BackupConfig
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 404
|
||||
// @Success 200 {object} BackupConfig "Returns backup configuration with encryption field"
|
||||
// @Failure 400 {object} map[string]string "Invalid database ID"
|
||||
// @Failure 401 {object} map[string]string "User not authenticated"
|
||||
// @Failure 404 {object} map[string]string "Backup configuration not found"
|
||||
// @Router /backup-configs/database/{id} [get]
|
||||
func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
_, err = c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
backupConfig, err := c.backupConfigService.GetBackupConfigByDbIdWithAuth(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusNotFound, gin.H{"error": "backup configuration not found"})
|
||||
@@ -119,24 +100,18 @@ func (c *BackupConfigController) GetBackupConfigByDbID(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /backup-configs/storage/{id}/is-using [get]
|
||||
func (c *BackupConfigController) IsStorageUsing(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid storage ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
isUsing, err := c.backupConfigService.IsStorageUsing(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
493
backend/internal/features/backups/config/controller_test.go
Normal file
493
backend/internal/features/backups/config/controller_test.go
Normal file
@@ -0,0 +1,493 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
"postgresus-backend/internal/util/period"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetBackupConfigController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot save backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can save backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
testResp := test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
|
||||
assert.Equal(t, 2, response.CpuCount)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+nonMember.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetBackupConfigByDbID_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get backup config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get backup config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
testResp := test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
if tt.expectSuccess {
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
} else {
|
||||
assert.Contains(t, string(testResp.Body), "backup configuration not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetBackupConfigByDbID_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/database/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.False(t, response.IsBackupsEnabled)
|
||||
assert.Equal(t, period.PeriodWeek, response.StorePeriod)
|
||||
assert.Equal(t, 1, response.CpuCount)
|
||||
assert.True(t, response.IsRetryIfFailed)
|
||||
assert.Equal(t, 3, response.MaxFailedTriesCount)
|
||||
assert.NotNil(t, response.BackupInterval)
|
||||
}
|
||||
|
||||
func Test_IsStorageUsing_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
isStorageOwner bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "storage owner can check storage usage",
|
||||
isStorageOwner: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-storage-owner cannot check storage usage",
|
||||
isStorageOwner: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
storageOwner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"Test Workspace",
|
||||
storageOwner,
|
||||
router,
|
||||
)
|
||||
storage := createTestStorage(workspace.ID)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isStorageOwner {
|
||||
testUserToken = storageOwner.Token
|
||||
} else {
|
||||
otherUser := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = otherUser.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response map[string]bool
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
isUsing, exists := response["isUsing"]
|
||||
assert.True(t, exists)
|
||||
assert.False(t, isUsing)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/storage/"+storage.ID.String()+"/is-using",
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "error")
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WithEncryptionNone_ConfigSaved(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.Equal(t, BackupEncryptionNone, response.Encryption)
|
||||
}
|
||||
|
||||
func Test_SaveBackupConfig_WithEncryptionEncrypted_ConfigSaved(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
timeOfDay := "04:00"
|
||||
request := BackupConfig{
|
||||
DatabaseID: database.ID,
|
||||
IsBackupsEnabled: true,
|
||||
StorePeriod: period.PeriodWeek,
|
||||
BackupInterval: &intervals.Interval{
|
||||
Interval: intervals.IntervalDaily,
|
||||
TimeOfDay: &timeOfDay,
|
||||
},
|
||||
SendNotificationsOn: []BackupNotificationType{
|
||||
NotificationBackupFailed,
|
||||
},
|
||||
CpuCount: 2,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionEncrypted,
|
||||
}
|
||||
|
||||
var response BackupConfig
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/backup-configs/save",
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.Equal(t, BackupEncryptionEncrypted, response.Encryption)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestStorage(workspaceID uuid.UUID) *storages.Storage {
|
||||
return storages.CreateTestStorage(workspaceID)
|
||||
}
|
||||
@@ -3,7 +3,7 @@ package backups_config
|
||||
import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
)
|
||||
|
||||
var backupConfigRepository = &BackupConfigRepository{}
|
||||
@@ -11,11 +11,11 @@ var backupConfigService = &BackupConfigService{
|
||||
backupConfigRepository,
|
||||
databases.GetDatabaseService(),
|
||||
storages.GetStorageService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
nil,
|
||||
}
|
||||
var backupConfigController = &BackupConfigController{
|
||||
backupConfigService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetBackupConfigController() *BackupConfigController {
|
||||
|
||||
@@ -6,3 +6,10 @@ const (
|
||||
NotificationBackupFailed BackupNotificationType = "BACKUP_FAILED"
|
||||
NotificationBackupSuccess BackupNotificationType = "BACKUP_SUCCESS"
|
||||
)
|
||||
|
||||
type BackupEncryption string
|
||||
|
||||
const (
|
||||
BackupEncryptionNone BackupEncryption = "NONE"
|
||||
BackupEncryptionEncrypted BackupEncryption = "ENCRYPTED"
|
||||
)
|
||||
|
||||
@@ -31,6 +31,8 @@ type BackupConfig struct {
|
||||
MaxFailedTriesCount int `json:"maxFailedTriesCount" gorm:"column:max_failed_tries_count;type:int;not null"`
|
||||
|
||||
CpuCount int `json:"cpuCount" gorm:"type:int;not null"`
|
||||
|
||||
Encryption BackupEncryption `json:"encryption" gorm:"column:encryption;type:text;not null;default:'NONE'"`
|
||||
}
|
||||
|
||||
func (h *BackupConfig) TableName() string {
|
||||
@@ -88,5 +90,26 @@ func (b *BackupConfig) Validate() error {
|
||||
return errors.New("max failed tries count must be greater than 0")
|
||||
}
|
||||
|
||||
if b.Encryption != "" && b.Encryption != BackupEncryptionNone &&
|
||||
b.Encryption != BackupEncryptionEncrypted {
|
||||
return errors.New("encryption must be NONE or ENCRYPTED")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *BackupConfig) Copy(newDatabaseID uuid.UUID) *BackupConfig {
|
||||
return &BackupConfig{
|
||||
DatabaseID: newDatabaseID,
|
||||
IsBackupsEnabled: b.IsBackupsEnabled,
|
||||
StorePeriod: b.StorePeriod,
|
||||
BackupIntervalID: uuid.Nil,
|
||||
BackupInterval: b.BackupInterval.Copy(),
|
||||
StorageID: b.StorageID,
|
||||
SendNotificationsOn: b.SendNotificationsOn,
|
||||
IsRetryIfFailed: b.IsRetryIfFailed,
|
||||
MaxFailedTriesCount: b.MaxFailedTriesCount,
|
||||
CpuCount: b.CpuCount,
|
||||
Encryption: b.Encryption,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
package backups_config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/intervals"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/period"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,6 +17,7 @@ type BackupConfigService struct {
|
||||
backupConfigRepository *BackupConfigRepository
|
||||
databaseService *databases.DatabaseService
|
||||
storageService *storages.StorageService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
|
||||
dbStorageChangeListener BackupConfigStorageChangeListener
|
||||
}
|
||||
@@ -32,11 +36,23 @@ func (s *BackupConfigService) SaveBackupConfigWithAuth(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
|
||||
database, err := s.databaseService.GetDatabase(user, backupConfig.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot save backup config for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to modify backup configuration")
|
||||
}
|
||||
|
||||
return s.SaveBackupConfig(backupConfig)
|
||||
}
|
||||
|
||||
@@ -66,19 +82,6 @@ func (s *BackupConfigService) SaveBackupConfig(
|
||||
}
|
||||
}
|
||||
|
||||
if !backupConfig.IsBackupsEnabled && existingConfig.StorageID != nil {
|
||||
if err := s.dbStorageChangeListener.OnBeforeBackupsStorageChange(
|
||||
backupConfig.DatabaseID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// we clear storage for disabled backups to allow
|
||||
// storage removal for unused storages
|
||||
backupConfig.Storage = nil
|
||||
backupConfig.StorageID = nil
|
||||
}
|
||||
|
||||
return s.backupConfigRepository.Save(backupConfig)
|
||||
}
|
||||
|
||||
@@ -130,6 +133,24 @@ func (s *BackupConfigService) GetBackupConfigsWithEnabledBackups() ([]*BackupCon
|
||||
return s.backupConfigRepository.GetWithEnabledBackups()
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) OnDatabaseCopied(originalDatabaseID, newDatabaseID uuid.UUID) {
|
||||
originalConfig, err := s.GetBackupConfigByDbId(originalDatabaseID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
newConfig := originalConfig.Copy(newDatabaseID)
|
||||
|
||||
_, err = s.SaveBackupConfig(newConfig)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) CreateDisabledBackupConfig(databaseID uuid.UUID) error {
|
||||
return s.initializeDefaultConfig(databaseID)
|
||||
}
|
||||
|
||||
func (s *BackupConfigService) initializeDefaultConfig(
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
@@ -150,6 +171,7 @@ func (s *BackupConfigService) initializeDefaultConfig(
|
||||
CpuCount: 1,
|
||||
IsRetryIfFailed: true,
|
||||
MaxFailedTriesCount: 3,
|
||||
Encryption: BackupEncryptionNone,
|
||||
})
|
||||
|
||||
return err
|
||||
|
||||
@@ -2,15 +2,18 @@ package databases
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DatabaseController struct {
|
||||
databaseService *DatabaseService
|
||||
userService *users.UserService
|
||||
databaseService *DatabaseService
|
||||
userService *users_services.UserService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -21,42 +24,43 @@ func (c *DatabaseController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
router.GET("/databases", c.GetDatabases)
|
||||
router.POST("/databases/:id/test-connection", c.TestDatabaseConnection)
|
||||
router.POST("/databases/test-connection-direct", c.TestDatabaseConnectionDirect)
|
||||
router.POST("/databases/:id/copy", c.CopyDatabase)
|
||||
router.GET("/databases/notifier/:id/is-using", c.IsNotifierUsing)
|
||||
|
||||
router.POST("/databases/is-readonly", c.IsUserReadOnly)
|
||||
router.POST("/databases/create-readonly-user", c.CreateReadOnlyUser)
|
||||
}
|
||||
|
||||
// CreateDatabase
|
||||
// @Summary Create a new database
|
||||
// @Description Create a new database configuration
|
||||
// @Description Create a new database configuration in a workspace
|
||||
// @Tags databases
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body Database true "Database creation data"
|
||||
// @Param request body Database true "Database creation data with workspaceId"
|
||||
// @Success 201 {object} Database
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases/create [post]
|
||||
func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
if request.WorkspaceID == nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
database, err := c.databaseService.CreateDatabase(user, &request)
|
||||
database, err := c.databaseService.CreateDatabase(user, *request.WorkspaceID, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -78,24 +82,18 @@ func (c *DatabaseController) CreateDatabase(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/update [post]
|
||||
func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.UpdateDatabase(user, &request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -115,24 +113,18 @@ func (c *DatabaseController) UpdateDatabase(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/{id} [delete]
|
||||
func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.DeleteDatabase(user, id); err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -152,24 +144,18 @@ func (c *DatabaseController) DeleteDatabase(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /databases/{id} [get]
|
||||
func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
database, err := c.databaseService.GetDatabase(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
@@ -180,30 +166,38 @@ func (c *DatabaseController) GetDatabase(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
// GetDatabases
|
||||
// @Summary Get databases
|
||||
// @Description Get all databases for the authenticated user
|
||||
// @Summary Get databases by workspace
|
||||
// @Description Get all databases for a specific workspace
|
||||
// @Tags databases
|
||||
// @Produce json
|
||||
// @Param workspace_id query string true "Workspace ID"
|
||||
// @Success 200 {array} Database
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases [get]
|
||||
func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
workspaceIDStr := ctx.Query("workspace_id")
|
||||
if workspaceIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
databases, err := c.databaseService.GetDatabasesByUser(user)
|
||||
workspaceID, err := uuid.Parse(workspaceIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
|
||||
return
|
||||
}
|
||||
|
||||
databases, err := c.databaseService.GetDatabasesByWorkspace(user, workspaceID)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -221,24 +215,18 @@ func (c *DatabaseController) GetDatabases(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/{id}/test-connection [post]
|
||||
func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.databaseService.TestDatabaseConnection(user, id); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -258,27 +246,18 @@ func (c *DatabaseController) TestDatabaseConnection(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /databases/test-connection-direct [post]
|
||||
func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
|
||||
_, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Set user ID for validation purposes
|
||||
request.UserID = user.ID
|
||||
|
||||
if err := c.databaseService.TestDatabaseConnectionDirect(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -299,24 +278,18 @@ func (c *DatabaseController) TestDatabaseConnectionDirect(ctx *gin.Context) {
|
||||
// @Failure 500
|
||||
// @Router /databases/notifier/{id}/is-using [get]
|
||||
func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid notifier ID"})
|
||||
return
|
||||
}
|
||||
|
||||
authorizationHeader := ctx.GetHeader("Authorization")
|
||||
if authorizationHeader == "" {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "authorization header is required"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := c.userService.GetUserFromToken(authorizationHeader)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "invalid token"})
|
||||
return
|
||||
}
|
||||
|
||||
isUsing, err := c.databaseService.IsNotifierUsing(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -325,3 +298,109 @@ func (c *DatabaseController) IsNotifierUsing(ctx *gin.Context) {
|
||||
|
||||
ctx.JSON(http.StatusOK, gin.H{"isUsing": isUsing})
|
||||
}
|
||||
|
||||
// CopyDatabase
|
||||
// @Summary Copy a database
|
||||
// @Description Copy an existing database configuration
|
||||
// @Tags databases
|
||||
// @Produce json
|
||||
// @Param id path string true "Database ID"
|
||||
// @Success 201 {object} Database
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 500
|
||||
// @Router /databases/{id}/copy [post]
|
||||
func (c *DatabaseController) CopyDatabase(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
id, err := uuid.Parse(ctx.Param("id"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid database ID"})
|
||||
return
|
||||
}
|
||||
|
||||
copiedDatabase, err := c.databaseService.CopyDatabase(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusCreated, copiedDatabase)
|
||||
}
|
||||
|
||||
// IsUserReadOnly
|
||||
// @Summary Check if database user is read-only
|
||||
// @Description Check if current database credentials have only read (SELECT) privileges
|
||||
// @Tags databases
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body Database true "Database configuration to check"
|
||||
// @Success 200 {object} IsReadOnlyResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /databases/is-readonly [post]
|
||||
func (c *DatabaseController) IsUserReadOnly(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
isReadOnly, err := c.databaseService.IsUserReadOnly(user, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, IsReadOnlyResponse{IsReadOnly: isReadOnly})
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser
|
||||
// @Summary Create read-only database user
|
||||
// @Description Create a new PostgreSQL user with read-only privileges for backup operations
|
||||
// @Tags databases
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security BearerAuth
|
||||
// @Param request body Database true "Database configuration to create user for"
|
||||
// @Success 200 {object} CreateReadOnlyUserResponse
|
||||
// @Failure 400 {object} map[string]string
|
||||
// @Failure 401 {object} map[string]string
|
||||
// @Failure 403 {object} map[string]string
|
||||
// @Router /databases/create-readonly-user [post]
|
||||
func (c *DatabaseController) CreateReadOnlyUser(ctx *gin.Context) {
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var request Database
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
username, password, err := c.databaseService.CreateReadOnlyUser(user, &request)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, CreateReadOnlyUserResponse{
|
||||
Username: username,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
|
||||
1002
backend/internal/features/databases/controller_test.go
Normal file
1002
backend/internal/features/databases/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -5,8 +5,10 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -17,7 +19,6 @@ type PostgresqlDatabase struct {
|
||||
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
|
||||
DatabaseID *uuid.UUID `json:"databaseId" gorm:"type:uuid;column:database_id"`
|
||||
RestoreID *uuid.UUID `json:"restoreId" gorm:"type:uuid;column:restore_id"`
|
||||
|
||||
Version tools.PostgresqlVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
@@ -58,11 +59,429 @@ func (p *PostgresqlDatabase) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) TestConnection(logger *slog.Logger) error {
|
||||
func (p *PostgresqlDatabase) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return testSingleDatabaseConnection(logger, ctx, p)
|
||||
return testSingleDatabaseConnection(logger, ctx, p, encryptor, databaseID)
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) HideSensitiveData() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.Password = ""
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
||||
p.Version = incoming.Version
|
||||
p.Host = incoming.Host
|
||||
p.Port = incoming.Port
|
||||
p.Username = incoming.Username
|
||||
p.Database = incoming.Database
|
||||
p.IsHttps = incoming.IsHttps
|
||||
|
||||
if incoming.Password != "" {
|
||||
p.Password = incoming.Password
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
databaseID uuid.UUID,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if p.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(databaseID, p.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Password = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsUserReadOnly checks if the database user has read-only privileges.
|
||||
//
|
||||
// This method performs a comprehensive security check by examining:
|
||||
// - Role-level attributes (superuser, createrole, createdb)
|
||||
// - Database-level privileges (CREATE, TEMP)
|
||||
// - Table-level write permissions (INSERT, UPDATE, DELETE, TRUNCATE, REFERENCES, TRIGGER)
|
||||
//
|
||||
// A user is considered read-only only if they have ZERO write privileges
|
||||
// across all three levels. This ensures the database user follows the
|
||||
// principle of least privilege for backup operations.
|
||||
func (p *PostgresqlDatabase) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
connStr := buildConnectionStringForDB(p, *p.Database, password)
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// LEVEL 1: Check role-level attributes
|
||||
var isSuperuser, canCreateRole, canCreateDB bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT
|
||||
rolsuper,
|
||||
rolcreaterole,
|
||||
rolcreatedb
|
||||
FROM pg_roles
|
||||
WHERE rolname = current_user
|
||||
`).Scan(&isSuperuser, &canCreateRole, &canCreateDB)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check role attributes: %w", err)
|
||||
}
|
||||
|
||||
if isSuperuser || canCreateRole || canCreateDB {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// LEVEL 2: Check database-level privileges
|
||||
var canCreate, canTemp bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT
|
||||
has_database_privilege(current_user, current_database(), 'CREATE') as can_create,
|
||||
has_database_privilege(current_user, current_database(), 'TEMP') as can_temp
|
||||
`).Scan(&canCreate, &canTemp)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check database privileges: %w", err)
|
||||
}
|
||||
|
||||
if canCreate || canTemp {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// LEVEL 2.5: Check schema-level CREATE privileges
|
||||
schemaRows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT nspname
|
||||
FROM pg_namespace n
|
||||
WHERE has_schema_privilege(current_user, n.nspname, 'CREATE')
|
||||
AND nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check schema privileges: %w", err)
|
||||
}
|
||||
defer schemaRows.Close()
|
||||
|
||||
// If user has CREATE privilege on any schema, they're not read-only
|
||||
if schemaRows.Next() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := schemaRows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating schema privileges: %w", err)
|
||||
}
|
||||
|
||||
// LEVEL 3: Check table-level write permissions
|
||||
rows, err := conn.Query(ctx, `
|
||||
SELECT DISTINCT privilege_type
|
||||
FROM information_schema.role_table_grants
|
||||
WHERE grantee = current_user
|
||||
AND table_schema NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check table privileges: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
writePrivileges := map[string]bool{
|
||||
"INSERT": true,
|
||||
"UPDATE": true,
|
||||
"DELETE": true,
|
||||
"TRUNCATE": true,
|
||||
"REFERENCES": true,
|
||||
"TRIGGER": true,
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var privilege string
|
||||
if err := rows.Scan(&privilege); err != nil {
|
||||
return false, fmt.Errorf("failed to scan privilege: %w", err)
|
||||
}
|
||||
|
||||
if writePrivileges[privilege] {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating privileges: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// CreateReadOnlyUser creates a new PostgreSQL user with read-only privileges.
|
||||
//
|
||||
// This method performs the following operations atomically in a single transaction:
|
||||
// 1. Creates a PostgreSQL user with a UUID-based password
|
||||
// 2. Grants CONNECT privilege on the database
|
||||
// 3. Grants USAGE on all non-system schemas
|
||||
// 4. Grants SELECT on all existing tables and sequences
|
||||
// 5. Sets default privileges for future tables and sequences
|
||||
//
|
||||
// Security features:
|
||||
// - Username format: "postgresus-{8-char-uuid}" for uniqueness
|
||||
// - Password: Full UUID (36 characters) for strong entropy
|
||||
// - Transaction safety: All operations rollback on any failure
|
||||
// - Retry logic: Up to 3 attempts if username collision occurs
|
||||
// - Pre-validation: Checks CREATEROLE privilege before starting transaction
|
||||
func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, string, error) {
|
||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
connStr := buildConnectionStringForDB(p, *p.Database, password)
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
// Pre-validate: Check if current user can create roles
|
||||
var canCreateRole, isSuperuser bool
|
||||
err = conn.QueryRow(ctx, `
|
||||
SELECT rolcreaterole, rolsuper
|
||||
FROM pg_roles
|
||||
WHERE rolname = current_user
|
||||
`).Scan(&canCreateRole, &isSuperuser)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to check permissions: %w", err)
|
||||
}
|
||||
if !canCreateRole && !isSuperuser {
|
||||
return "", "", errors.New("current database user lacks CREATEROLE privilege")
|
||||
}
|
||||
|
||||
// Retry logic for username collision
|
||||
maxRetries := 3
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
username := fmt.Sprintf("postgresus-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
if rollbackErr := tx.Rollback(ctx); rollbackErr != nil {
|
||||
logger.Error("Failed to rollback transaction", "error", rollbackErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Step 1: Create PostgreSQL user with LOGIN privilege
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`, username, newPassword),
|
||||
)
|
||||
if err != nil {
|
||||
if err.Error() != "" && attempt < maxRetries-1 {
|
||||
continue
|
||||
}
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Step 1.5: Revoke CREATE privilege from PUBLIC role on public schema
|
||||
// This is necessary because all PostgreSQL users inherit CREATE privilege on the
|
||||
// public schema through the PUBLIC role. This is a one-time operation that affects
|
||||
// the entire database, making it more secure by default.
|
||||
// Note: This only affects the public schema; other schemas are unaffected.
|
||||
_, err = tx.Exec(ctx, `REVOKE CREATE ON SCHEMA public FROM PUBLIC`)
|
||||
if err != nil {
|
||||
logger.Error("Failed to revoke CREATE on public from PUBLIC", "error", err)
|
||||
if !strings.Contains(err.Error(), "schema \"public\" does not exist") &&
|
||||
!strings.Contains(err.Error(), "permission denied") {
|
||||
return "", "", fmt.Errorf("failed to revoke CREATE from PUBLIC: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, username))
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
username,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 2: Grant database connection privilege and revoke TEMP
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE %s TO "%s"`, *p.Database, username),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant connect privilege: %w", err)
|
||||
}
|
||||
|
||||
// Revoke TEMP privilege from PUBLIC role (like CREATE on public schema, TEMP is granted to PUBLIC by default)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE TEMP ON DATABASE %s FROM PUBLIC`, *p.Database))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to revoke TEMP from PUBLIC", "error", err)
|
||||
}
|
||||
|
||||
// Also revoke from the specific user (belt and suspenders)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE TEMP ON DATABASE %s FROM "%s"`, *p.Database, username),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", username)
|
||||
}
|
||||
|
||||
// Step 3: Discover all user-created schemas
|
||||
rows, err := tx.Query(ctx, `
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
`)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to get schemas: %w", err)
|
||||
}
|
||||
|
||||
var schemas []string
|
||||
for rows.Next() {
|
||||
var schema string
|
||||
if err := rows.Scan(&schema); err != nil {
|
||||
rows.Close()
|
||||
return "", "", fmt.Errorf("failed to scan schema: %w", err)
|
||||
}
|
||||
schemas = append(schemas, schema)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return "", "", fmt.Errorf("error iterating schemas: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Grant USAGE on each schema and explicitly prevent CREATE
|
||||
for _, schema := range schemas {
|
||||
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA "%s" FROM "%s"`, schema, username),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
"Failed to revoke CREATE on schema",
|
||||
"error",
|
||||
err,
|
||||
"schema",
|
||||
schema,
|
||||
"username",
|
||||
username,
|
||||
)
|
||||
}
|
||||
|
||||
// Grant only USAGE (not CREATE)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT USAGE ON SCHEMA "%s" TO "%s"`, schema, username),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant usage on schema %s: %w", schema, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Grant SELECT on ALL existing tables and sequences
|
||||
grantSelectSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('GRANT SELECT ON ALL TABLES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, username, username)
|
||||
|
||||
_, err = tx.Exec(ctx, grantSelectSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant select on tables: %w", err)
|
||||
}
|
||||
|
||||
// Step 6: Set default privileges for FUTURE tables and sequences
|
||||
defaultPrivilegesSQL := fmt.Sprintf(`
|
||||
DO $$
|
||||
DECLARE
|
||||
schema_rec RECORD;
|
||||
BEGIN
|
||||
FOR schema_rec IN
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name NOT IN ('pg_catalog', 'information_schema')
|
||||
LOOP
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON TABLES TO "%s"', schema_rec.schema_name);
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, username, username)
|
||||
|
||||
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to set default privileges: %w", err)
|
||||
}
|
||||
|
||||
// Step 7: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, username)).
|
||||
Scan(&verifyUsername)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to verify user creation: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return "", "", fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
logger.Info("Read-only user created successfully", "username", username)
|
||||
return username, newPassword, nil
|
||||
}
|
||||
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
// testSingleDatabaseConnection tests connection to a specific database for pg_dump
|
||||
@@ -70,14 +489,22 @@ func testSingleDatabaseConnection(
|
||||
logger *slog.Logger,
|
||||
ctx context.Context,
|
||||
postgresDb *PostgresqlDatabase,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
// For single database backup, we need to connect to the specific database
|
||||
if postgresDb.Database == nil || *postgresDb.Database == "" {
|
||||
return errors.New("database name is required for single database backup (pg_dump)")
|
||||
}
|
||||
|
||||
// Decrypt password if needed
|
||||
password, err := decryptPasswordIfNeeded(postgresDb.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
// Build connection string for the specific database
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database)
|
||||
connStr := buildConnectionStringForDB(postgresDb, *postgresDb.Database, password)
|
||||
|
||||
// Test connection
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
@@ -160,7 +587,7 @@ func testBasicOperations(ctx context.Context, conn *pgx.Conn, dbName string) err
|
||||
}
|
||||
|
||||
// buildConnectionStringForDB builds connection string for specific database
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password string) string {
|
||||
sslMode := "disable"
|
||||
if p.IsHttps {
|
||||
sslMode = "require"
|
||||
@@ -170,8 +597,19 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string) string {
|
||||
p.Host,
|
||||
p.Port,
|
||||
p.Username,
|
||||
p.Password,
|
||||
password,
|
||||
dbName,
|
||||
sslMode,
|
||||
)
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, error) {
|
||||
if encryptor == nil {
|
||||
return password, nil
|
||||
}
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,323 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := pgModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Admin user should not be read-only")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version string
|
||||
port string
|
||||
}{
|
||||
{"PostgreSQL 12", "12", env.TestPostgres12Port},
|
||||
{"PostgreSQL 13", "13", env.TestPostgres13Port},
|
||||
{"PostgreSQL 14", "14", env.TestPostgres14Port},
|
||||
{"PostgreSQL 15", "15", env.TestPostgres15Port},
|
||||
{"PostgreSQL 16", "16", env.TestPostgres16Port},
|
||||
{"PostgreSQL 17", "17", env.TestPostgres17Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToPostgresContainer(t, tc.port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
DROP TABLE IF EXISTS readonly_test CASCADE;
|
||||
DROP TABLE IF EXISTS hack_table CASCADE;
|
||||
DROP TABLE IF EXISTS future_table CASCADE;
|
||||
CREATE TABLE readonly_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO readonly_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "postgresus-"))
|
||||
|
||||
readOnlyModel := &PostgresqlDatabase{
|
||||
Version: pgModel.Version,
|
||||
Host: pgModel.Host,
|
||||
Port: pgModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: pgModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host,
|
||||
container.Port,
|
||||
username,
|
||||
password,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_HaveSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO future_table (data) VALUES ('future_data');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`
|
||||
CREATE SCHEMA IF NOT EXISTS schema_a;
|
||||
CREATE SCHEMA IF NOT EXISTS schema_b;
|
||||
CREATE TABLE schema_a.table_a (id INT, data TEXT);
|
||||
CREATE TABLE schema_b.table_b (id INT, data TEXT);
|
||||
INSERT INTO schema_a.table_a VALUES (1, 'data_a');
|
||||
INSERT INTO schema_b.table_b VALUES (2, 'data_b');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := createPostgresModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var dataA string
|
||||
err = readOnlyConn.Get(&dataA, "SELECT data FROM schema_a.table_a LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_a", dataA)
|
||||
|
||||
var dataB string
|
||||
err = readOnlyConn.Get(&dataB, "SELECT data FROM schema_b.table_b LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "data_b", dataB)
|
||||
|
||||
// Clean up: Drop user with CASCADE to handle default privilege dependencies
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP SCHEMA schema_a CASCADE; DROP SCHEMA schema_b CASCADE;`)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToPostgresContainer(t *testing.T, port string) *PostgresContainer {
|
||||
dbName := "testdb"
|
||||
password := "testpassword"
|
||||
username := "testuser"
|
||||
host := "localhost"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
host, portInt, username, password, dbName)
|
||||
|
||||
db, err := sqlx.Connect("postgres", dsn)
|
||||
assert.NoError(t, err)
|
||||
|
||||
var versionStr string
|
||||
err = db.Get(&versionStr, "SELECT version()")
|
||||
assert.NoError(t, err)
|
||||
|
||||
return &PostgresContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createPostgresModel(container *PostgresContainer) *PostgresqlDatabase {
|
||||
var versionStr string
|
||||
err := container.DB.Get(&versionStr, "SELECT version()")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
version := extractPostgresVersion(versionStr)
|
||||
|
||||
return &PostgresqlDatabase{
|
||||
Version: version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
|
||||
func extractPostgresVersion(versionStr string) tools.PostgresqlVersion {
|
||||
if strings.Contains(versionStr, "PostgreSQL 12") {
|
||||
return tools.GetPostgresqlVersionEnum("12")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 13") {
|
||||
return tools.GetPostgresqlVersionEnum("13")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 14") {
|
||||
return tools.GetPostgresqlVersionEnum("14")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 15") {
|
||||
return tools.GetPostgresqlVersionEnum("15")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 16") {
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
} else if strings.Contains(versionStr, "PostgreSQL 17") {
|
||||
return tools.GetPostgresqlVersionEnum("17")
|
||||
}
|
||||
|
||||
return tools.GetPostgresqlVersionEnum("16")
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -14,11 +17,16 @@ var databaseService = &DatabaseService{
|
||||
logger.GetLogger(),
|
||||
[]DatabaseCreationListener{},
|
||||
[]DatabaseRemoveListener{},
|
||||
[]DatabaseCopyListener{},
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
var databaseController = &DatabaseController{
|
||||
databaseService,
|
||||
users.GetUserService(),
|
||||
users_services.GetUserService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
func GetDatabaseService() *DatabaseService {
|
||||
@@ -28,3 +36,7 @@ func GetDatabaseService() *DatabaseService {
|
||||
func GetDatabaseController() *DatabaseController {
|
||||
return databaseController
|
||||
}
|
||||
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(databaseService)
|
||||
}
|
||||
|
||||
10
backend/internal/features/databases/dto.go
Normal file
10
backend/internal/features/databases/dto.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package databases
|
||||
|
||||
type CreateReadOnlyUserResponse struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type IsReadOnlyResponse struct {
|
||||
IsReadOnly bool `json:"isReadOnly"`
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package databases
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -11,7 +12,13 @@ type DatabaseValidator interface {
|
||||
}
|
||||
|
||||
type DatabaseConnector interface {
|
||||
TestConnection(logger *slog.Logger) error
|
||||
TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error
|
||||
|
||||
HideSensitiveData()
|
||||
}
|
||||
|
||||
type DatabaseCreationListener interface {
|
||||
@@ -21,3 +28,7 @@ type DatabaseCreationListener interface {
|
||||
type DatabaseRemoveListener interface {
|
||||
OnBeforeDatabaseRemove(databaseID uuid.UUID) error
|
||||
}
|
||||
|
||||
type DatabaseCopyListener interface {
|
||||
OnDatabaseCopied(originalDatabaseID, newDatabaseID uuid.UUID)
|
||||
}
|
||||
|
||||
@@ -5,16 +5,20 @@ import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
UserID uuid.UUID `json:"userId" gorm:"column:user_id;type:uuid;not null"`
|
||||
Name string `json:"name" gorm:"column:name;type:text;not null"`
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
|
||||
// WorkspaceID can be null when a database is created via restore operation
|
||||
// outside the context of any workspace
|
||||
WorkspaceID *uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;type:uuid"`
|
||||
Name string `json:"name" gorm:"column:name;type:text;not null"`
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
|
||||
@@ -35,6 +39,10 @@ func (d *Database) Validate() error {
|
||||
|
||||
switch d.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if d.Postgresql == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
|
||||
return d.Postgresql.Validate()
|
||||
default:
|
||||
return errors.New("invalid database type: " + string(d.Type))
|
||||
@@ -49,8 +57,35 @@ func (d *Database) ValidateUpdate(old, new Database) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) TestConnection(logger *slog.Logger) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger)
|
||||
func (d *Database) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
return d.getSpecificDatabase().TestConnection(logger, encryptor, d.ID)
|
||||
}
|
||||
|
||||
func (d *Database) HideSensitiveData() {
|
||||
d.getSpecificDatabase().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) Update(incoming *Database) {
|
||||
d.Name = incoming.Name
|
||||
d.Type = incoming.Type
|
||||
d.Notifiers = incoming.Notifiers
|
||||
|
||||
switch d.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if d.Postgresql != nil && incoming.Postgresql != nil {
|
||||
d.Postgresql.Update(incoming.Postgresql)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) getSpecificDatabase() DatabaseConnector {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/storage"
|
||||
|
||||
@@ -21,9 +22,12 @@ func (r *DatabaseRepository) Save(database *Database) (*Database, error) {
|
||||
err := db.Transaction(func(tx *gorm.DB) error {
|
||||
switch database.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if database.Postgresql != nil {
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
if database.Postgresql == nil {
|
||||
return errors.New("postgresql configuration is required for PostgreSQL database")
|
||||
}
|
||||
|
||||
// Ensure DatabaseID is always set and never nil
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
}
|
||||
|
||||
if isNew {
|
||||
@@ -43,17 +47,15 @@ func (r *DatabaseRepository) Save(database *Database) (*Database, error) {
|
||||
// Save the specific database type
|
||||
switch database.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if database.Postgresql != nil {
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
if database.Postgresql.ID == uuid.Nil {
|
||||
database.Postgresql.ID = uuid.New()
|
||||
if err := tx.Create(database.Postgresql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(database.Postgresql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
if database.Postgresql.ID == uuid.Nil {
|
||||
database.Postgresql.ID = uuid.New()
|
||||
if err := tx.Create(database.Postgresql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(database.Postgresql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,14 +92,14 @@ func (r *DatabaseRepository) FindByID(id uuid.UUID) (*Database, error) {
|
||||
return &database, nil
|
||||
}
|
||||
|
||||
func (r *DatabaseRepository) FindByUserID(userID uuid.UUID) ([]*Database, error) {
|
||||
func (r *DatabaseRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Database, error) {
|
||||
var databases []*Database
|
||||
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Postgresql").
|
||||
Preload("Notifiers").
|
||||
Where("user_id = ?", userID).
|
||||
Where("workspace_id = ?", workspaceID).
|
||||
Order("CASE WHEN health_status = 'UNAVAILABLE' THEN 1 WHEN health_status = 'AVAILABLE' THEN 2 WHEN health_status IS NULL THEN 3 ELSE 4 END, name ASC").
|
||||
Find(&databases).Error; err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
package databases
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
"time"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -17,6 +24,11 @@ type DatabaseService struct {
|
||||
|
||||
dbCreationListener []DatabaseCreationListener
|
||||
dbRemoveListener []DatabaseRemoveListener
|
||||
dbCopyListener []DatabaseCopyListener
|
||||
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *DatabaseService) AddDbCreationListener(
|
||||
@@ -31,17 +43,36 @@ func (s *DatabaseService) AddDbRemoveListener(
|
||||
s.dbRemoveListener = append(s.dbRemoveListener, dbRemoveListener)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) AddDbCopyListener(
|
||||
dbCopyListener DatabaseCopyListener,
|
||||
) {
|
||||
s.dbCopyListener = append(s.dbCopyListener, dbCopyListener)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CreateDatabase(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
database *Database,
|
||||
) (*Database, error) {
|
||||
database.UserID = user.ID
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(workspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to create database in this workspace")
|
||||
}
|
||||
|
||||
database.WorkspaceID = &workspaceID
|
||||
|
||||
if err := database.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database, err := s.dbRepository.Save(database)
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
database, err = s.dbRepository.Save(database)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -50,6 +81,12 @@ func (s *DatabaseService) CreateDatabase(
|
||||
listener.OnDatabaseCreated(database.ID)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database created: %s", database.Name),
|
||||
&user.ID,
|
||||
&workspaceID,
|
||||
)
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
@@ -66,24 +103,43 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if existingDatabase.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return errors.New("cannot update database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to update this database")
|
||||
}
|
||||
|
||||
// Validate the update
|
||||
if err := database.ValidateUpdate(*existingDatabase, *database); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := database.Validate(); err != nil {
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s.dbRepository.Save(database)
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.dbRepository.Save(existingDatabase)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database updated: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -96,8 +152,16 @@ func (s *DatabaseService) DeleteDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if existingDatabase.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return errors.New("cannot delete database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to delete this database")
|
||||
}
|
||||
|
||||
for _, listener := range s.dbRemoveListener {
|
||||
@@ -106,6 +170,12 @@ func (s *DatabaseService) DeleteDatabase(
|
||||
}
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database deleted: %s", existingDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.dbRepository.Delete(id)
|
||||
}
|
||||
|
||||
@@ -118,17 +188,44 @@ func (s *DatabaseService) GetDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("you have not access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access this database")
|
||||
}
|
||||
|
||||
database.HideSensitiveData()
|
||||
return database, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetDatabasesByUser(
|
||||
func (s *DatabaseService) GetDatabasesByWorkspace(
|
||||
user *users_models.User,
|
||||
workspaceID uuid.UUID,
|
||||
) ([]*Database, error) {
|
||||
return s.dbRepository.FindByUserID(user.ID)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(workspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to access this workspace")
|
||||
}
|
||||
|
||||
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, database := range databases {
|
||||
database.HideSensitiveData()
|
||||
}
|
||||
|
||||
return databases, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) IsNotifierUsing(
|
||||
@@ -152,11 +249,19 @@ func (s *DatabaseService) TestDatabaseConnection(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("you have not access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot test connection for database without workspace")
|
||||
}
|
||||
|
||||
err = database.TestConnection(s.logger)
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canAccess {
|
||||
return errors.New("insufficient permissions to test connection for this database")
|
||||
}
|
||||
|
||||
err = database.TestConnection(s.logger, s.fieldEncryptor)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
database.LastBackupErrorMessage = &lastSaveError
|
||||
@@ -176,7 +281,31 @@ func (s *DatabaseService) TestDatabaseConnection(
|
||||
func (s *DatabaseService) TestDatabaseConnectionDirect(
|
||||
database *Database,
|
||||
) error {
|
||||
return database.TestConnection(s.logger)
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && existingDatabase.WorkspaceID != nil &&
|
||||
*existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
} else {
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
return usingDatabase.TestConnection(s.logger, s.fieldEncryptor)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) GetDatabaseByID(
|
||||
@@ -220,6 +349,81 @@ func (s *DatabaseService) SetLastBackupTime(databaseID uuid.UUID, backupTime tim
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CopyDatabase(
|
||||
user *users_models.User,
|
||||
databaseID uuid.UUID,
|
||||
) (*Database, error) {
|
||||
existingDatabase, err := s.dbRepository.FindByID(databaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot copy database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canManage {
|
||||
return nil, errors.New("insufficient permissions to copy this database")
|
||||
}
|
||||
|
||||
newDatabase := &Database{
|
||||
ID: uuid.Nil,
|
||||
WorkspaceID: existingDatabase.WorkspaceID,
|
||||
Name: existingDatabase.Name + " (Copy)",
|
||||
Type: existingDatabase.Type,
|
||||
Notifiers: existingDatabase.Notifiers,
|
||||
LastBackupTime: nil,
|
||||
LastBackupErrorMessage: nil,
|
||||
HealthStatus: existingDatabase.HealthStatus,
|
||||
}
|
||||
|
||||
switch existingDatabase.Type {
|
||||
case DatabaseTypePostgres:
|
||||
if existingDatabase.Postgresql != nil {
|
||||
newDatabase.Postgresql = &postgresql.PostgresqlDatabase{
|
||||
ID: uuid.Nil,
|
||||
DatabaseID: nil,
|
||||
Version: existingDatabase.Postgresql.Version,
|
||||
Host: existingDatabase.Postgresql.Host,
|
||||
Port: existingDatabase.Postgresql.Port,
|
||||
Username: existingDatabase.Postgresql.Username,
|
||||
Password: existingDatabase.Postgresql.Password,
|
||||
Database: existingDatabase.Postgresql.Database,
|
||||
IsHttps: existingDatabase.Postgresql.IsHttps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := newDatabase.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
copiedDatabase, err := s.dbRepository.Save(newDatabase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, listener := range s.dbCreationListener {
|
||||
listener.OnDatabaseCreated(copiedDatabase.ID)
|
||||
}
|
||||
|
||||
for _, listener := range s.dbCopyListener {
|
||||
listener.OnDatabaseCopied(databaseID, copiedDatabase.ID)
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Database copied: %s to %s", existingDatabase.Name, copiedDatabase.Name),
|
||||
&user.ID,
|
||||
existingDatabase.WorkspaceID,
|
||||
)
|
||||
|
||||
return copiedDatabase, nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) SetHealthStatus(
|
||||
databaseID uuid.UUID,
|
||||
healthStatus *HealthStatus,
|
||||
@@ -237,3 +441,164 @@ func (s *DatabaseService) SetHealthStatus(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) OnBeforeWorkspaceDeletion(workspaceID uuid.UUID) error {
|
||||
databases, err := s.dbRepository.FindByWorkspaceID(workspaceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(databases) > 0 {
|
||||
return fmt.Errorf(
|
||||
"workspace contains %d databases that must be deleted",
|
||||
len(databases),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DatabaseService) IsUserReadOnly(
|
||||
user *users_models.User,
|
||||
database *Database,
|
||||
) (bool, error) {
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return false, errors.New("cannot check user for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
*existingDatabase.WorkspaceID,
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this database")
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return false, errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
} else {
|
||||
if database.WorkspaceID != nil {
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !canAccess {
|
||||
return false, errors.New("insufficient permissions to access this workspace")
|
||||
}
|
||||
}
|
||||
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
if usingDatabase.Type != DatabaseTypePostgres {
|
||||
return false, errors.New("read-only check only supported for PostgreSQL databases")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return usingDatabase.Postgresql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CreateReadOnlyUser(
|
||||
user *users_models.User,
|
||||
database *Database,
|
||||
) (string, string, error) {
|
||||
var usingDatabase *Database
|
||||
|
||||
if database.ID != uuid.Nil {
|
||||
existingDatabase, err := s.dbRepository.FindByID(database.ID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if existingDatabase.WorkspaceID == nil {
|
||||
return "", "", errors.New("cannot create user for database without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*existingDatabase.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if !canManage {
|
||||
return "", "", errors.New("insufficient permissions to manage this database")
|
||||
}
|
||||
|
||||
if database.WorkspaceID != nil && *existingDatabase.WorkspaceID != *database.WorkspaceID {
|
||||
return "", "", errors.New("database does not belong to this workspace")
|
||||
}
|
||||
|
||||
existingDatabase.Update(database)
|
||||
|
||||
if err := existingDatabase.Validate(); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
usingDatabase = existingDatabase
|
||||
} else {
|
||||
if database.WorkspaceID != nil {
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, user)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if !canManage {
|
||||
return "", "", errors.New("insufficient permissions to manage this workspace")
|
||||
}
|
||||
}
|
||||
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
if usingDatabase.Type != DatabaseTypePostgres {
|
||||
return "", "", errors.New("read-only user creation only supported for PostgreSQL")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
username, password, err := usingDatabase.Postgresql.CreateReadOnlyUser(
|
||||
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
if usingDatabase.WorkspaceID != nil {
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf(
|
||||
"Read-only user created for database: %s (username: %s)",
|
||||
usingDatabase.Name,
|
||||
username,
|
||||
),
|
||||
&user.ID,
|
||||
usingDatabase.WorkspaceID,
|
||||
)
|
||||
}
|
||||
|
||||
return username, password, nil
|
||||
}
|
||||
|
||||
@@ -10,14 +10,14 @@ import (
|
||||
)
|
||||
|
||||
func CreateTestDatabase(
|
||||
userID uuid.UUID,
|
||||
workspaceID uuid.UUID,
|
||||
storage *storages.Storage,
|
||||
notifier *notifiers.Notifier,
|
||||
) *Database {
|
||||
database := &Database{
|
||||
UserID: userID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "test " + uuid.New().String(),
|
||||
Type: DatabaseTypePostgres,
|
||||
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
|
||||
9
backend/internal/features/encryption/secrets/di.go
Normal file
9
backend/internal/features/encryption/secrets/di.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package secrets
|
||||
|
||||
var secretKeyService = &SecretKeyService{
|
||||
nil,
|
||||
}
|
||||
|
||||
func GetSecretKeyService() *SecretKeyService {
|
||||
return secretKeyService
|
||||
}
|
||||
1
backend/internal/features/encryption/secrets/model.go
Normal file
1
backend/internal/features/encryption/secrets/model.go
Normal file
@@ -0,0 +1 @@
|
||||
package secrets
|
||||
73
backend/internal/features/encryption/secrets/service.go
Normal file
73
backend/internal/features/encryption/secrets/service.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
user_models "postgresus-backend/internal/features/users/models"
|
||||
"postgresus-backend/internal/storage"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type SecretKeyService struct {
|
||||
cachedKey *string
|
||||
}
|
||||
|
||||
func (s *SecretKeyService) MigrateKeyFromDbToFileIfExist() error {
|
||||
var secretKey user_models.SecretKey
|
||||
|
||||
err := storage.GetDb().First(&secretKey).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("failed to check for secret key in database: %w", err)
|
||||
}
|
||||
|
||||
if secretKey.Secret == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
secretKeyPath := config.GetEnv().SecretKeyPath
|
||||
if err := os.WriteFile(secretKeyPath, []byte(secretKey.Secret), 0600); err != nil {
|
||||
return fmt.Errorf("failed to write secret key to file: %w", err)
|
||||
}
|
||||
|
||||
if err := storage.GetDb().Exec("DELETE FROM secret_keys").Error; err != nil {
|
||||
return fmt.Errorf("failed to delete secret key from database: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SecretKeyService) GetSecretKey() (string, error) {
|
||||
if s.cachedKey != nil {
|
||||
return *s.cachedKey, nil
|
||||
}
|
||||
|
||||
secretKeyPath := config.GetEnv().SecretKeyPath
|
||||
data, err := os.ReadFile(secretKeyPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
newKey := s.generateNewSecretKey()
|
||||
if err := os.WriteFile(secretKeyPath, []byte(newKey), 0600); err != nil {
|
||||
return "", fmt.Errorf("failed to write new secret key: %w", err)
|
||||
}
|
||||
s.cachedKey = &newKey
|
||||
return newKey, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to read secret key file: %w", err)
|
||||
}
|
||||
|
||||
key := string(data)
|
||||
s.cachedKey = &key
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (s *SecretKeyService) generateNewSecretKey() string {
|
||||
return uuid.New().String() + uuid.New().String()
|
||||
}
|
||||
@@ -13,7 +13,7 @@ type HealthcheckAttemptBackgroundService struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptBackgroundService) RunBackgroundTasks() {
|
||||
func (s *HealthcheckAttemptBackgroundService) Run() {
|
||||
// first healthcheck immediately
|
||||
s.checkDatabases()
|
||||
|
||||
|
||||
@@ -10,23 +10,34 @@ import (
|
||||
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
user := users.GetTestUser()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
|
||||
storage := storages.CreateTestStorage(user.UserID)
|
||||
notifier := notifiers.CreateTestNotifier(user.UserID)
|
||||
// Create workspace directly via service
|
||||
workspace, err := workspaces_testing.CreateTestWorkspaceDirect("Test Workspace", user.UserID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create workspace: %v", err)
|
||||
}
|
||||
|
||||
defer storages.RemoveTestStorage(storage.ID)
|
||||
defer notifiers.RemoveTestNotifier(notifier)
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
notifier := notifiers.CreateTestNotifier(workspace.ID)
|
||||
|
||||
defer func() {
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
notifiers.RemoveTestNotifier(notifier)
|
||||
workspaces_testing.RemoveTestWorkspaceDirect(workspace.ID)
|
||||
}()
|
||||
|
||||
t.Run("Test_DbAttemptFailed_DbMarkedAsUnavailable", func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
@@ -94,7 +105,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbNotMarkerdAsDownAfterFirstAttempt",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
@@ -160,7 +171,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbShouldBeConsideredAsDownOnThirdFailedAttempt_DbMarkerdAsDownAfterThirdFailedAttempt",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Make sure DB is available
|
||||
@@ -237,7 +248,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
)
|
||||
|
||||
t.Run("Test_UnavailableDbAttemptSucceed_DbMarkedAsAvailable", func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Make sure DB is unavailable
|
||||
@@ -311,7 +322,7 @@ func Test_CheckPgHealthUseCase(t *testing.T) {
|
||||
t.Run(
|
||||
"Test_DbHealthcheckExecutedFast_HealthcheckNotExecutedFasterThanInterval",
|
||||
func(t *testing.T) {
|
||||
database := databases.CreateTestDatabase(user.UserID, storage, notifier)
|
||||
database := databases.CreateTestDatabase(workspace.ID, storage, notifier)
|
||||
defer databases.RemoveTestDatabase(database)
|
||||
|
||||
// Setup mock notifier sender
|
||||
|
||||
@@ -2,7 +2,7 @@ package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
|
||||
type HealthcheckAttemptController struct {
|
||||
healthcheckAttemptService *HealthcheckAttemptService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -31,9 +30,9 @@ func (c *HealthcheckAttemptController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-attempts/{databaseId} [get]
|
||||
func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -43,7 +42,7 @@ func (c *HealthcheckAttemptController) GetAttemptsByDatabase(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
afterDate := time.Now().UTC()
|
||||
afterDate := time.Now().UTC().Add(-7 * 24 * time.Hour)
|
||||
if afterDateStr := ctx.Query("afterDate"); afterDateStr != "" {
|
||||
parsedDate, err := time.Parse(time.RFC3339, afterDateStr)
|
||||
if err != nil {
|
||||
|
||||
261
backend/internal/features/healthcheck/attempt/controller_test.go
Normal file
261
backend/internal/features/healthcheck/attempt/controller_test.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package healthcheck_attempt
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetHealthcheckAttemptController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get healthcheck attempts",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get healthcheck attempts",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get healthcheck attempts",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
pastTime := time.Now().UTC().Add(-1 * time.Hour)
|
||||
createTestHealthcheckAttemptWithTime(
|
||||
database.ID,
|
||||
databases.HealthStatusAvailable,
|
||||
pastTime,
|
||||
)
|
||||
createTestHealthcheckAttemptWithTime(
|
||||
database.ID,
|
||||
databases.HealthStatusUnavailable,
|
||||
pastTime.Add(-30*time.Minute),
|
||||
)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.GreaterOrEqual(t, len(response), 2)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "forbidden")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_FiltersByAfterDate(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
oldTime := time.Now().UTC().Add(-2 * time.Hour)
|
||||
recentTime := time.Now().UTC().Add(-30 * time.Minute)
|
||||
|
||||
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusAvailable, oldTime)
|
||||
createTestHealthcheckAttemptWithTime(database.ID, databases.HealthStatusUnavailable, recentTime)
|
||||
createTestHealthcheckAttempt(database.ID, databases.HealthStatusAvailable)
|
||||
|
||||
afterDate := time.Now().UTC().Add(-1 * time.Hour)
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf(
|
||||
"/api/v1/healthcheck-attempts/%s?afterDate=%s",
|
||||
database.ID.String(),
|
||||
afterDate.Format(time.RFC3339),
|
||||
),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, 2, len(response))
|
||||
for _, attempt := range response {
|
||||
assert.True(t, attempt.CreatedAt.After(afterDate) || attempt.CreatedAt.Equal(afterDate))
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetAttemptsByDatabase_ReturnsEmptyListForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response []*HealthcheckAttempt
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-attempts/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, 0, len(response))
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
|
||||
func createTestHealthcheckAttempt(databaseID uuid.UUID, status databases.HealthStatus) {
|
||||
createTestHealthcheckAttemptWithTime(databaseID, status, time.Now().UTC())
|
||||
}
|
||||
|
||||
func createTestHealthcheckAttemptWithTime(
|
||||
databaseID uuid.UUID,
|
||||
status databases.HealthStatus,
|
||||
createdAt time.Time,
|
||||
) {
|
||||
repo := GetHealthcheckAttemptRepository()
|
||||
attempt := &HealthcheckAttempt{
|
||||
ID: uuid.New(),
|
||||
DatabaseID: databaseID,
|
||||
Status: status,
|
||||
CreatedAt: createdAt,
|
||||
}
|
||||
if err := repo.Create(attempt); err != nil {
|
||||
panic("Failed to create test healthcheck attempt: " + err.Error())
|
||||
}
|
||||
}
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"postgresus-backend/internal/features/databases"
|
||||
healthcheck_config "postgresus-backend/internal/features/healthcheck/config"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ var healthcheckAttemptRepository = &HealthcheckAttemptRepository{}
|
||||
var healthcheckAttemptService = &HealthcheckAttemptService{
|
||||
healthcheckAttemptRepository,
|
||||
databases.GetDatabaseService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
var checkPgHealthUseCase = &CheckPgHealthUseCase{
|
||||
@@ -27,7 +28,10 @@ var healthcheckAttemptBackgroundService = &HealthcheckAttemptBackgroundService{
|
||||
}
|
||||
var healthcheckAttemptController = &HealthcheckAttemptController{
|
||||
healthcheckAttemptService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetHealthcheckAttemptRepository() *HealthcheckAttemptRepository {
|
||||
return healthcheckAttemptRepository
|
||||
}
|
||||
|
||||
func GetHealthcheckAttemptService() *HealthcheckAttemptService {
|
||||
|
||||
@@ -53,7 +53,7 @@ func (r *HealthcheckAttemptRepository) DeleteOlderThan(
|
||||
Delete(&HealthcheckAttempt{}).Error
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) Insert(
|
||||
func (r *HealthcheckAttemptRepository) Create(
|
||||
attempt *HealthcheckAttempt,
|
||||
) error {
|
||||
if attempt.ID == uuid.Nil {
|
||||
@@ -67,6 +67,12 @@ func (r *HealthcheckAttemptRepository) Insert(
|
||||
return storage.GetDb().Create(attempt).Error
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) Insert(
|
||||
attempt *HealthcheckAttempt,
|
||||
) error {
|
||||
return r.Create(attempt)
|
||||
}
|
||||
|
||||
func (r *HealthcheckAttemptRepository) FindByDatabaseIDWithLimit(
|
||||
databaseID uuid.UUID,
|
||||
limit int,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
type HealthcheckAttemptService struct {
|
||||
healthcheckAttemptRepository *HealthcheckAttemptRepository
|
||||
databaseService *databases.DatabaseService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
|
||||
@@ -24,7 +26,15 @@ func (s *HealthcheckAttemptService) GetAttemptsByDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access healthcheck attempts for databases without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("forbidden")
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package healthcheck_config
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
|
||||
type HealthcheckConfigController struct {
|
||||
healthcheckConfigService *HealthcheckConfigService
|
||||
userService *users.UserService
|
||||
}
|
||||
|
||||
func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -31,9 +30,9 @@ func (c *HealthcheckConfigController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-config [post]
|
||||
func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -65,9 +64,9 @@ func (c *HealthcheckConfigController) SaveHealthcheckConfig(ctx *gin.Context) {
|
||||
// @Failure 401
|
||||
// @Router /healthcheck-config/{databaseId} [get]
|
||||
func (c *HealthcheckConfigController) GetHealthcheckConfig(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
328
backend/internal/features/healthcheck/config/controller_test.go
Normal file
328
backend/internal/features/healthcheck/config/controller_test.go
Normal file
@@ -0,0 +1,328 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_controllers "postgresus-backend/internal/features/workspaces/controllers"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func createTestRouter() *gin.Engine {
|
||||
router := workspaces_testing.CreateTestRouter(
|
||||
workspaces_controllers.GetWorkspaceController(),
|
||||
workspaces_controllers.GetMembershipController(),
|
||||
databases.GetDatabaseController(),
|
||||
GetHealthcheckConfigController(),
|
||||
)
|
||||
return router
|
||||
}
|
||||
|
||||
func Test_SaveHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer cannot save healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "global admin can save healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
}
|
||||
|
||||
request := HealthcheckConfigDTO{
|
||||
DatabaseID: database.ID,
|
||||
IsHealthcheckEnabled: true,
|
||||
IsSentNotificationWhenUnavailable: true,
|
||||
IntervalMinutes: 5,
|
||||
AttemptsBeforeConcideredAsDown: 3,
|
||||
StoreAttemptsDays: 7,
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response map[string]string
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
assert.Contains(t, response["message"], "successfully")
|
||||
} else {
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+testUserToken,
|
||||
request,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SaveHealthcheckConfig_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
|
||||
request := HealthcheckConfigDTO{
|
||||
DatabaseID: database.ID,
|
||||
IsHealthcheckEnabled: true,
|
||||
IsSentNotificationWhenUnavailable: true,
|
||||
IntervalMinutes: 5,
|
||||
AttemptsBeforeConcideredAsDown: 3,
|
||||
StoreAttemptsDays: 7,
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config",
|
||||
"Bearer "+nonMember.Token,
|
||||
request,
|
||||
http.StatusBadRequest,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_GetHealthcheckConfig_PermissionsEnforced(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
workspaceRole *users_enums.WorkspaceRole
|
||||
isGlobalAdmin bool
|
||||
expectSuccess bool
|
||||
expectedStatusCode int
|
||||
}{
|
||||
{
|
||||
name: "workspace owner can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleOwner; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace admin can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleAdmin; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace member can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleMember; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "workspace viewer can get healthcheck config",
|
||||
workspaceRole: func() *users_enums.WorkspaceRole { r := users_enums.WorkspaceRoleViewer; return &r }(),
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "global admin can get healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: true,
|
||||
expectSuccess: true,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "non-member cannot get healthcheck config",
|
||||
workspaceRole: nil,
|
||||
isGlobalAdmin: false,
|
||||
expectSuccess: false,
|
||||
expectedStatusCode: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var testUserToken string
|
||||
if tt.isGlobalAdmin {
|
||||
admin := users_testing.CreateTestUser(users_enums.UserRoleAdmin)
|
||||
testUserToken = admin.Token
|
||||
} else if tt.workspaceRole != nil && *tt.workspaceRole == users_enums.WorkspaceRoleOwner {
|
||||
testUserToken = owner.Token
|
||||
} else if tt.workspaceRole != nil {
|
||||
member := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspaces_testing.AddMemberToWorkspace(workspace, member, *tt.workspaceRole, owner.Token, router)
|
||||
testUserToken = member.Token
|
||||
} else {
|
||||
nonMember := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
testUserToken = nonMember.Token
|
||||
}
|
||||
|
||||
if tt.expectSuccess {
|
||||
var response HealthcheckConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsHealthcheckEnabled)
|
||||
} else {
|
||||
testResp := test_utils.MakeGetRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+testUserToken,
|
||||
tt.expectedStatusCode,
|
||||
)
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetHealthcheckConfig_ReturnsDefaultConfigForNewDatabase(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
database := createTestDatabaseViaAPI("Test Database", workspace.ID, owner.Token, router)
|
||||
|
||||
var response HealthcheckConfig
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/healthcheck-config/"+database.ID.String(),
|
||||
"Bearer "+owner.Token,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
assert.Equal(t, database.ID, response.DatabaseID)
|
||||
assert.True(t, response.IsHealthcheckEnabled)
|
||||
assert.True(t, response.IsSentNotificationWhenUnavailable)
|
||||
assert.Equal(t, 1, response.IntervalMinutes)
|
||||
assert.Equal(t, 3, response.AttemptsBeforeConcideredAsDown)
|
||||
assert.Equal(t, 7, response.StoreAttemptsDays)
|
||||
}
|
||||
|
||||
func createTestDatabaseViaAPI(
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
token string,
|
||||
router *gin.Engine,
|
||||
) *databases.Database {
|
||||
testDbName := "test_db"
|
||||
request := databases.Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: name,
|
||||
Type: databases.DatabaseTypePostgres,
|
||||
Postgresql: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
panic("Failed to create database")
|
||||
}
|
||||
|
||||
var database databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &database); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &database
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package healthcheck_config
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/users"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -10,11 +11,12 @@ var healthcheckConfigRepository = &HealthcheckConfigRepository{}
|
||||
var healthcheckConfigService = &HealthcheckConfigService{
|
||||
databases.GetDatabaseService(),
|
||||
healthcheckConfigRepository,
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
logger.GetLogger(),
|
||||
}
|
||||
var healthcheckConfigController = &HealthcheckConfigController{
|
||||
healthcheckConfigService,
|
||||
users.GetUserService(),
|
||||
}
|
||||
|
||||
func GetHealthcheckConfigService() *HealthcheckConfigService {
|
||||
|
||||
@@ -2,9 +2,12 @@ package healthcheck_config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -12,6 +15,8 @@ import (
|
||||
type HealthcheckConfigService struct {
|
||||
databaseService *databases.DatabaseService
|
||||
healthcheckConfigRepository *HealthcheckConfigRepository
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
@@ -33,8 +38,16 @@ func (s *HealthcheckConfigService) Save(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return errors.New("cannot modify healthcheck config for databases without workspace")
|
||||
}
|
||||
|
||||
canManage, err := s.workspaceService.CanUserManageDBs(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !canManage {
|
||||
return errors.New("insufficient permissions to modify healthcheck config")
|
||||
}
|
||||
|
||||
healthcheckConfig := configDTO.ToDTO()
|
||||
@@ -60,6 +73,12 @@ func (s *HealthcheckConfigService) Save(
|
||||
}
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
fmt.Sprintf("Healthcheck config updated for database '%s'", database.Name),
|
||||
&user.ID,
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -72,8 +91,16 @@ func (s *HealthcheckConfigService) GetByDatabaseID(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if database.UserID != user.ID {
|
||||
return nil, errors.New("user does not have access to this database")
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot access healthcheck config for databases without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, &user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to view healthcheck config")
|
||||
}
|
||||
|
||||
config, err := s.healthcheckConfigRepository.GetByDatabaseID(database.ID)
|
||||
|
||||
@@ -64,6 +64,16 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
|
||||
}
|
||||
}
|
||||
|
||||
func (i *Interval) Copy() *Interval {
|
||||
return &Interval{
|
||||
ID: uuid.Nil,
|
||||
Interval: i.Interval,
|
||||
TimeOfDay: i.TimeOfDay,
|
||||
Weekday: i.Weekday,
|
||||
DayOfMonth: i.DayOfMonth,
|
||||
}
|
||||
}
|
||||
|
||||
// daily trigger: honour the TimeOfDay slot and catch up the previous one
|
||||
func (i *Interval) shouldTriggerDaily(now, lastBackup time.Time) bool {
|
||||
if i.TimeOfDay == nil {
|
||||
|
||||
@@ -2,15 +2,16 @@ package notifiers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/users"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type NotifierController struct {
|
||||
notifierService *NotifierService
|
||||
userService *users.UserService
|
||||
notifierService *NotifierService
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
}
|
||||
|
||||
func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
@@ -29,35 +30,40 @@ func (c *NotifierController) RegisterRoutes(router *gin.RouterGroup) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param notifier body Notifier true "Notifier data"
|
||||
// @Param request body Notifier true "Notifier data with workspaceId"
|
||||
// @Success 200 {object} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers [post]
|
||||
func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var notifier Notifier
|
||||
if err := ctx.ShouldBindJSON(¬ifier); err != nil {
|
||||
var request Notifier
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := notifier.Validate(); err != nil {
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SaveNotifier(user, request.WorkspaceID, &request); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SaveNotifier(user, ¬ifier); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx.JSON(http.StatusOK, notifier)
|
||||
ctx.JSON(http.StatusOK, request)
|
||||
}
|
||||
|
||||
// GetNotifier
|
||||
@@ -70,11 +76,12 @@ func (c *NotifierController) SaveNotifier(ctx *gin.Context) {
|
||||
// @Success 200 {object} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id} [get]
|
||||
func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -86,6 +93,10 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
|
||||
notifier, err := c.notifierService.GetNotifier(user, id)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -95,22 +106,41 @@ func (c *NotifierController) GetNotifier(ctx *gin.Context) {
|
||||
|
||||
// GetNotifiers
|
||||
// @Summary Get all notifiers
|
||||
// @Description Get all notifiers for the current user
|
||||
// @Description Get all notifiers for a workspace
|
||||
// @Tags notifiers
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param workspace_id query string true "Workspace ID"
|
||||
// @Success 200 {array} Notifier
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers [get]
|
||||
func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
notifiers, err := c.notifierService.GetNotifiers(user)
|
||||
workspaceIDStr := ctx.Query("workspace_id")
|
||||
if workspaceIDStr == "" {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspace_id query parameter is required"})
|
||||
return
|
||||
}
|
||||
|
||||
workspaceID, err := uuid.Parse(workspaceIDStr)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace_id"})
|
||||
return
|
||||
}
|
||||
|
||||
notifiers, err := c.notifierService.GetNotifiers(user, workspaceID)
|
||||
if err != nil {
|
||||
if err.Error() == "insufficient permissions to view notifiers in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -128,11 +158,12 @@ func (c *NotifierController) GetNotifiers(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id} [delete]
|
||||
func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,13 +173,11 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
notifier, err := c.notifierService.GetNotifier(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.DeleteNotifier(user, notifier.ID); err != nil {
|
||||
if err := c.notifierService.DeleteNotifier(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to manage notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -166,11 +195,12 @@ func (c *NotifierController) DeleteNotifier(ctx *gin.Context) {
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/{id}/test [post]
|
||||
func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -181,6 +211,10 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotification(user, id); err != nil {
|
||||
if err.Error() == "insufficient permissions to test notifier in this workspace" {
|
||||
ctx.JSON(http.StatusForbidden, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -195,28 +229,44 @@ func (c *NotifierController) SendTestNotification(ctx *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param Authorization header string true "JWT token"
|
||||
// @Param notifier body Notifier true "Notifier data"
|
||||
// @Param request body Notifier true "Notifier data with workspaceId"
|
||||
// @Success 200
|
||||
// @Failure 400
|
||||
// @Failure 401
|
||||
// @Failure 403
|
||||
// @Router /notifiers/direct-test [post]
|
||||
func (c *NotifierController) SendTestNotificationDirect(ctx *gin.Context) {
|
||||
user, err := c.userService.GetUserFromToken(ctx.GetHeader("Authorization"))
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
user, ok := users_middleware.GetUserFromContext(ctx)
|
||||
if !ok {
|
||||
ctx.JSON(http.StatusUnauthorized, gin.H{"error": "User not authenticated"})
|
||||
return
|
||||
}
|
||||
|
||||
var notifier Notifier
|
||||
if err := ctx.ShouldBindJSON(¬ifier); err != nil {
|
||||
var request Notifier
|
||||
if err := ctx.ShouldBindJSON(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// For direct test, associate with the current user
|
||||
notifier.UserID = user.ID
|
||||
if request.WorkspaceID == uuid.Nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": "workspaceId is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotificationToNotifier(¬ifier); err != nil {
|
||||
canView, _, err := c.workspaceService.CanUserAccessWorkspace(request.WorkspaceID, user)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !canView {
|
||||
ctx.JSON(
|
||||
http.StatusForbidden,
|
||||
gin.H{"error": "insufficient permissions to test notifier in this workspace"},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.notifierService.SendTestNotificationToNotifier(&request); err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
1041
backend/internal/features/notifiers/controller_test.go
Normal file
1041
backend/internal/features/notifiers/controller_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,9 @@
|
||||
package notifiers
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/users"
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -9,10 +11,13 @@ var notifierRepository = &NotifierRepository{}
|
||||
var notifierService = &NotifierService{
|
||||
notifierRepository,
|
||||
logger.GetLogger(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var notifierController = &NotifierController{
|
||||
notifierService,
|
||||
users.GetUserService(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
}
|
||||
|
||||
func GetNotifierController() *NotifierController {
|
||||
@@ -22,3 +27,10 @@ func GetNotifierController() *NotifierController {
|
||||
func GetNotifierService() *NotifierService {
|
||||
return notifierService
|
||||
}
|
||||
|
||||
func GetNotifierRepository() *NotifierRepository {
|
||||
return notifierRepository
|
||||
}
|
||||
func SetupDependencies() {
|
||||
workspaces_services.GetWorkspaceService().AddWorkspaceDeletionListener(notifierService)
|
||||
}
|
||||
|
||||
@@ -8,4 +8,5 @@ const (
|
||||
NotifierTypeWebhook NotifierType = "WEBHOOK"
|
||||
NotifierTypeSlack NotifierType = "SLACK"
|
||||
NotifierTypeDiscord NotifierType = "DISCORD"
|
||||
NotifierTypeTeams NotifierType = "TEAMS"
|
||||
)
|
||||
|
||||
@@ -1,9 +1,21 @@
|
||||
package notifiers
|
||||
|
||||
import "log/slog"
|
||||
import (
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
)
|
||||
|
||||
type NotificationSender interface {
|
||||
Send(logger *slog.Logger, heading string, message string) error
|
||||
Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error
|
||||
|
||||
Validate() error
|
||||
Validate(encryptor encryption.FieldEncryptor) error
|
||||
|
||||
HideSensitiveData()
|
||||
|
||||
EncryptSensitiveData(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
@@ -6,41 +6,49 @@ import (
|
||||
discord_notifier "postgresus-backend/internal/features/notifiers/models/discord"
|
||||
"postgresus-backend/internal/features/notifiers/models/email_notifier"
|
||||
slack_notifier "postgresus-backend/internal/features/notifiers/models/slack"
|
||||
teams_notifier "postgresus-backend/internal/features/notifiers/models/teams"
|
||||
telegram_notifier "postgresus-backend/internal/features/notifiers/models/telegram"
|
||||
webhook_notifier "postgresus-backend/internal/features/notifiers/models/webhook"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type Notifier struct {
|
||||
ID uuid.UUID `json:"id" gorm:"column:id;primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
UserID uuid.UUID `json:"userId" gorm:"column:user_id;not null;type:uuid;index"`
|
||||
WorkspaceID uuid.UUID `json:"workspaceId" gorm:"column:workspace_id;not null;type:uuid;index"`
|
||||
Name string `json:"name" gorm:"column:name;not null;type:varchar(255)"`
|
||||
NotifierType NotifierType `json:"notifierType" gorm:"column:notifier_type;not null;type:varchar(50)"`
|
||||
LastSendError *string `json:"lastSendError" gorm:"column:last_send_error;type:text"`
|
||||
|
||||
// specific notifier
|
||||
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
|
||||
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
|
||||
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
|
||||
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
|
||||
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
|
||||
TelegramNotifier *telegram_notifier.TelegramNotifier `json:"telegramNotifier" gorm:"foreignKey:NotifierID"`
|
||||
EmailNotifier *email_notifier.EmailNotifier `json:"emailNotifier" gorm:"foreignKey:NotifierID"`
|
||||
WebhookNotifier *webhook_notifier.WebhookNotifier `json:"webhookNotifier" gorm:"foreignKey:NotifierID"`
|
||||
SlackNotifier *slack_notifier.SlackNotifier `json:"slackNotifier" gorm:"foreignKey:NotifierID"`
|
||||
DiscordNotifier *discord_notifier.DiscordNotifier `json:"discordNotifier" gorm:"foreignKey:NotifierID"`
|
||||
TeamsNotifier *teams_notifier.TeamsNotifier `json:"teamsNotifier,omitempty" gorm:"foreignKey:NotifierID;constraint:OnDelete:CASCADE"`
|
||||
}
|
||||
|
||||
func (n *Notifier) TableName() string {
|
||||
return "notifiers"
|
||||
}
|
||||
|
||||
func (n *Notifier) Validate() error {
|
||||
func (n *Notifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
|
||||
return n.getSpecificNotifier().Validate()
|
||||
return n.getSpecificNotifier().Validate(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
err := n.getSpecificNotifier().Send(logger, heading, message)
|
||||
func (n *Notifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
err := n.getSpecificNotifier().Send(encryptor, logger, heading, message)
|
||||
|
||||
if err != nil {
|
||||
lastSendError := err.Error()
|
||||
@@ -52,6 +60,46 @@ func (n *Notifier) Send(logger *slog.Logger, heading string, message string) err
|
||||
return err
|
||||
}
|
||||
|
||||
func (n *Notifier) HideSensitiveData() {
|
||||
n.getSpecificNotifier().HideSensitiveData()
|
||||
}
|
||||
|
||||
func (n *Notifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
return n.getSpecificNotifier().EncryptSensitiveData(encryptor)
|
||||
}
|
||||
|
||||
func (n *Notifier) Update(incoming *Notifier) {
|
||||
n.Name = incoming.Name
|
||||
n.NotifierType = incoming.NotifierType
|
||||
|
||||
switch n.NotifierType {
|
||||
case NotifierTypeTelegram:
|
||||
if n.TelegramNotifier != nil && incoming.TelegramNotifier != nil {
|
||||
n.TelegramNotifier.Update(incoming.TelegramNotifier)
|
||||
}
|
||||
case NotifierTypeEmail:
|
||||
if n.EmailNotifier != nil && incoming.EmailNotifier != nil {
|
||||
n.EmailNotifier.Update(incoming.EmailNotifier)
|
||||
}
|
||||
case NotifierTypeWebhook:
|
||||
if n.WebhookNotifier != nil && incoming.WebhookNotifier != nil {
|
||||
n.WebhookNotifier.Update(incoming.WebhookNotifier)
|
||||
}
|
||||
case NotifierTypeSlack:
|
||||
if n.SlackNotifier != nil && incoming.SlackNotifier != nil {
|
||||
n.SlackNotifier.Update(incoming.SlackNotifier)
|
||||
}
|
||||
case NotifierTypeDiscord:
|
||||
if n.DiscordNotifier != nil && incoming.DiscordNotifier != nil {
|
||||
n.DiscordNotifier.Update(incoming.DiscordNotifier)
|
||||
}
|
||||
case NotifierTypeTeams:
|
||||
if n.TeamsNotifier != nil && incoming.TeamsNotifier != nil {
|
||||
n.TeamsNotifier.Update(incoming.TeamsNotifier)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Notifier) getSpecificNotifier() NotificationSender {
|
||||
switch n.NotifierType {
|
||||
case NotifierTypeTelegram:
|
||||
@@ -64,6 +112,8 @@ func (n *Notifier) getSpecificNotifier() NotificationSender {
|
||||
return n.SlackNotifier
|
||||
case NotifierTypeDiscord:
|
||||
return n.DiscordNotifier
|
||||
case NotifierTypeTeams:
|
||||
return n.TeamsNotifier
|
||||
default:
|
||||
panic("unknown notifier type: " + string(n.NotifierType))
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -21,7 +22,7 @@ func (d *DiscordNotifier) TableName() string {
|
||||
return "discord_notifiers"
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Validate() error {
|
||||
func (d *DiscordNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -29,7 +30,17 @@ func (d *DiscordNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (d *DiscordNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
@@ -44,7 +55,7 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal Discord payload: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", d.ChannelWebhookURL, bytes.NewReader(jsonPayload))
|
||||
req, err := http.NewRequest("POST", webhookURL, bytes.NewReader(jsonPayload))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -71,3 +82,24 @@ func (d *DiscordNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) HideSensitiveData() {
|
||||
d.ChannelWebhookURL = ""
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) Update(incoming *DiscordNotifier) {
|
||||
if incoming.ChannelWebhookURL != "" {
|
||||
d.ChannelWebhookURL = incoming.ChannelWebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DiscordNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if d.ChannelWebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(d.NotifierID, d.ChannelWebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
d.ChannelWebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -27,13 +28,14 @@ type EmailNotifier struct {
|
||||
SMTPPort int `json:"smtpPort" gorm:"not null;column:smtp_port"`
|
||||
SMTPUser string `json:"smtpUser" gorm:"type:varchar(255);column:smtp_user"`
|
||||
SMTPPassword string `json:"smtpPassword" gorm:"type:varchar(255);column:smtp_password"`
|
||||
From string `json:"from" gorm:"type:varchar(255);column:from_email"`
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) TableName() string {
|
||||
return "email_notifiers"
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Validate() error {
|
||||
func (e *EmailNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if e.TargetEmail == "" {
|
||||
return errors.New("target email is required")
|
||||
}
|
||||
@@ -54,11 +56,29 @@ func (e *EmailNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (e *EmailNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
// Decrypt SMTP password if provided
|
||||
var smtpPassword string
|
||||
if e.SMTPPassword != "" {
|
||||
decrypted, err := encryptor.Decrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt SMTP password: %w", err)
|
||||
}
|
||||
smtpPassword = decrypted
|
||||
}
|
||||
|
||||
// Compose email
|
||||
from := e.SMTPUser
|
||||
from := e.From
|
||||
if from == "" {
|
||||
from = "noreply@" + e.SMTPHost
|
||||
from = e.SMTPUser
|
||||
if from == "" {
|
||||
from = "noreply@" + e.SMTPHost
|
||||
}
|
||||
}
|
||||
|
||||
to := []string{e.TargetEmail}
|
||||
@@ -72,15 +92,16 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
)
|
||||
body := message
|
||||
fromHeader := fmt.Sprintf("From: %s\r\n", from)
|
||||
toHeader := fmt.Sprintf("To: %s\r\n", e.TargetEmail)
|
||||
|
||||
// Combine all parts of the email
|
||||
emailContent := []byte(fromHeader + subject + mime + body)
|
||||
emailContent := []byte(fromHeader + toHeader + subject + mime + body)
|
||||
|
||||
addr := net.JoinHostPort(e.SMTPHost, fmt.Sprintf("%d", e.SMTPPort))
|
||||
timeout := DefaultTimeout
|
||||
|
||||
// Determine if authentication is required
|
||||
isAuthRequired := e.SMTPUser != "" && e.SMTPPassword != ""
|
||||
isAuthRequired := e.SMTPUser != "" && smtpPassword != ""
|
||||
|
||||
// Handle different port scenarios
|
||||
if e.SMTPPort == ImplicitTLSPort {
|
||||
@@ -111,7 +132,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Set up authentication only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -174,7 +195,7 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
|
||||
// Authenticate only if credentials are provided
|
||||
if isAuthRequired {
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, e.SMTPPassword, e.SMTPHost)
|
||||
auth := smtp.PlainAuth("", e.SMTPUser, smtpPassword, e.SMTPHost)
|
||||
if err := client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP authentication failed: %w", err)
|
||||
}
|
||||
@@ -208,3 +229,30 @@ func (e *EmailNotifier) Send(logger *slog.Logger, heading string, message string
|
||||
return client.Quit()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) HideSensitiveData() {
|
||||
e.SMTPPassword = ""
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) Update(incoming *EmailNotifier) {
|
||||
e.TargetEmail = incoming.TargetEmail
|
||||
e.SMTPHost = incoming.SMTPHost
|
||||
e.SMTPPort = incoming.SMTPPort
|
||||
e.SMTPUser = incoming.SMTPUser
|
||||
e.From = incoming.From
|
||||
|
||||
if incoming.SMTPPassword != "" {
|
||||
e.SMTPPassword = incoming.SMTPPassword
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EmailNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if e.SMTPPassword != "" {
|
||||
encrypted, err := encryptor.Encrypt(e.NotifierID, e.SMTPPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt SMTP password: %w", err)
|
||||
}
|
||||
e.SMTPPassword = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -23,7 +24,7 @@ type SlackNotifier struct {
|
||||
|
||||
func (s *SlackNotifier) TableName() string { return "slack_notifiers" }
|
||||
|
||||
func (s *SlackNotifier) Validate() error {
|
||||
func (s *SlackNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -43,7 +44,16 @@ func (s *SlackNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error {
|
||||
func (s *SlackNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
full := fmt.Sprintf("*%s*", heading)
|
||||
|
||||
if message != "" {
|
||||
@@ -60,6 +70,7 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
maxAttempts = 5
|
||||
defaultBackoff = 2 * time.Second // when Retry-After header missing
|
||||
backoffMultiplier = 1.5 // use exponential growth
|
||||
requestTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -67,6 +78,10 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
attempts = 0
|
||||
)
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
}
|
||||
|
||||
for {
|
||||
attempts++
|
||||
|
||||
@@ -80,9 +95,9 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json; charset=utf-8")
|
||||
req.Header.Set("Authorization", "Bearer "+s.BotToken)
|
||||
req.Header.Set("Authorization", "Bearer "+botToken)
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("send slack message: %w", err)
|
||||
}
|
||||
@@ -132,3 +147,26 @@ func (s *SlackNotifier) Send(logger *slog.Logger, heading, message string) error
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) HideSensitiveData() {
|
||||
s.BotToken = ""
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) Update(incoming *SlackNotifier) {
|
||||
s.TargetChatID = incoming.TargetChatID
|
||||
|
||||
if incoming.BotToken != "" {
|
||||
s.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SlackNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if s.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(s.NotifierID, s.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
s.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
133
backend/internal/features/notifiers/models/teams/model.go
Normal file
133
backend/internal/features/notifiers/models/teams/model.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package teams_notifier
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type TeamsNotifier struct {
|
||||
NotifierID uuid.UUID `gorm:"type:uuid;primaryKey;column:notifier_id" json:"notifierId"`
|
||||
WebhookURL string `gorm:"type:text;not null;column:power_automate_url" json:"powerAutomateUrl"`
|
||||
}
|
||||
|
||||
func (TeamsNotifier) TableName() string {
|
||||
return "teams_notifiers"
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL == "" {
|
||||
return errors.New("webhook_url is required")
|
||||
}
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
u, err := url.Parse(webhookURL)
|
||||
if err != nil || (u.Scheme != "http" && u.Scheme != "https") {
|
||||
return errors.New("invalid webhook_url")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type cardAttachment struct {
|
||||
ContentType string `json:"contentType"`
|
||||
Content any `json:"content"`
|
||||
}
|
||||
|
||||
type payload struct {
|
||||
Title string `json:"title"`
|
||||
Text string `json:"text"`
|
||||
Attachments []cardAttachment `json:"attachments,omitempty"`
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading, message string,
|
||||
) error {
|
||||
if err := n.Validate(encryptor); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
webhookURL, err := encryptor.Decrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
card := map[string]any{
|
||||
"type": "AdaptiveCard",
|
||||
"version": "1.4",
|
||||
"body": []any{
|
||||
map[string]any{
|
||||
"type": "TextBlock",
|
||||
"size": "Medium",
|
||||
"weight": "Bolder",
|
||||
"text": heading,
|
||||
},
|
||||
map[string]any{"type": "TextBlock", "wrap": true, "text": message},
|
||||
},
|
||||
}
|
||||
|
||||
p := payload{
|
||||
Title: heading,
|
||||
Text: message,
|
||||
Attachments: []cardAttachment{
|
||||
{ContentType: "application/vnd.microsoft.card.adaptive", Content: card},
|
||||
},
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(p)
|
||||
req, err := http.NewRequest(http.MethodPost, webhookURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
logger.Error("failed to close response body", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return fmt.Errorf("teams webhook returned status %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) HideSensitiveData() {
|
||||
n.WebhookURL = ""
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) Update(incoming *TeamsNotifier) {
|
||||
if incoming.WebhookURL != "" {
|
||||
n.WebhookURL = incoming.WebhookURL
|
||||
}
|
||||
}
|
||||
|
||||
func (n *TeamsNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if n.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(n.NotifierID, n.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
n.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -24,7 +25,7 @@ func (t *TelegramNotifier) TableName() string {
|
||||
return "telegram_notifiers"
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Validate() error {
|
||||
func (t *TelegramNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken == "" {
|
||||
return errors.New("bot token is required")
|
||||
}
|
||||
@@ -36,13 +37,23 @@ func (t *TelegramNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *TelegramNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
botToken, err := encryptor.Decrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt bot token: %w", err)
|
||||
}
|
||||
|
||||
fullMessage := heading
|
||||
if message != "" {
|
||||
fullMessage = fmt.Sprintf("%s\n\n%s", heading, message)
|
||||
}
|
||||
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", t.BotToken)
|
||||
apiURL := fmt.Sprintf("https://api.telegram.org/bot%s/sendMessage", botToken)
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("chat_id", t.TargetChatID)
|
||||
@@ -80,3 +91,27 @@ func (t *TelegramNotifier) Send(logger *slog.Logger, heading string, message str
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) HideSensitiveData() {
|
||||
t.BotToken = ""
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) Update(incoming *TelegramNotifier) {
|
||||
t.TargetChatID = incoming.TargetChatID
|
||||
t.ThreadID = incoming.ThreadID
|
||||
|
||||
if incoming.BotToken != "" {
|
||||
t.BotToken = incoming.BotToken
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TelegramNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.BotToken != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.BotToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt bot token: %w", err)
|
||||
}
|
||||
t.BotToken = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
@@ -23,7 +24,7 @@ func (t *WebhookNotifier) TableName() string {
|
||||
return "webhook_notifiers"
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Validate() error {
|
||||
func (t *WebhookNotifier) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL == "" {
|
||||
return errors.New("webhook URL is required")
|
||||
}
|
||||
@@ -35,11 +36,21 @@ func (t *WebhookNotifier) Validate() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message string) error {
|
||||
func (t *WebhookNotifier) Send(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
heading string,
|
||||
message string,
|
||||
) error {
|
||||
webhookURL, err := encryptor.Decrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt webhook URL: %w", err)
|
||||
}
|
||||
|
||||
switch t.WebhookMethod {
|
||||
case WebhookMethodGET:
|
||||
reqURL := fmt.Sprintf("%s?heading=%s&message=%s",
|
||||
t.WebhookURL,
|
||||
webhookURL,
|
||||
url.QueryEscape(heading),
|
||||
url.QueryEscape(message),
|
||||
)
|
||||
@@ -76,7 +87,7 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("failed to marshal webhook payload: %w", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(t.WebhookURL, "application/json", bytes.NewReader(body))
|
||||
resp, err := http.Post(webhookURL, "application/json", bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send POST webhook: %w", err)
|
||||
}
|
||||
@@ -102,3 +113,22 @@ func (t *WebhookNotifier) Send(logger *slog.Logger, heading string, message stri
|
||||
return fmt.Errorf("unsupported webhook method: %s", t.WebhookMethod)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) HideSensitiveData() {
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) Update(incoming *WebhookNotifier) {
|
||||
t.WebhookURL = incoming.WebhookURL
|
||||
t.WebhookMethod = incoming.WebhookMethod
|
||||
}
|
||||
|
||||
func (t *WebhookNotifier) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if t.WebhookURL != "" {
|
||||
encrypted, err := encryptor.Encrypt(t.NotifierID, t.WebhookURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt webhook URL: %w", err)
|
||||
}
|
||||
t.WebhookURL = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user