mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
Compare commits
63 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da9b279e8b | ||
|
|
7a5654a80a | ||
|
|
ff94e06306 | ||
|
|
3ae8761666 | ||
|
|
70e0a59a82 | ||
|
|
e1f466c965 | ||
|
|
a0f284e06b | ||
|
|
8638b2d136 | ||
|
|
16d4f506bc | ||
|
|
c100d94a92 | ||
|
|
f14739a1fb | ||
|
|
b7d2521088 | ||
|
|
eb8e5aa428 | ||
|
|
1f030bd8fb | ||
|
|
b278a79104 | ||
|
|
b74ae734af | ||
|
|
d21a9398c6 | ||
|
|
6ad7b95b7d | ||
|
|
8432d1626f | ||
|
|
d7f631fa93 | ||
|
|
c3fb2aa529 | ||
|
|
1817937409 | ||
|
|
3172396668 | ||
|
|
9cd5c8c57c | ||
|
|
d8826d85c3 | ||
|
|
49fdd46cbe | ||
|
|
c6261d434b | ||
|
|
918002acde | ||
|
|
c0721a43e1 | ||
|
|
461e15cd7a | ||
|
|
69a53936f5 | ||
|
|
2bafec3c19 | ||
|
|
422b44dfdc | ||
|
|
51d7fe54d0 | ||
|
|
6e2d63626c | ||
|
|
260c7a1188 | ||
|
|
ace94c144b | ||
|
|
b666cd9e2e | ||
|
|
9dac63430d | ||
|
|
8217906c7a | ||
|
|
db71a5ef7b | ||
|
|
df78e296b3 | ||
|
|
fda3bf9b98 | ||
|
|
e19f449c60 | ||
|
|
5944d7c4b6 | ||
|
|
1f5c9d3d01 | ||
|
|
d27b885fc1 | ||
|
|
45054bc4b5 | ||
|
|
09f27019e8 | ||
|
|
cba8fdf49c | ||
|
|
41c72cf7b6 | ||
|
|
f04a8b7a82 | ||
|
|
552167e4ef | ||
|
|
be42cfab1f | ||
|
|
ea34ced676 | ||
|
|
09cb1488b3 | ||
|
|
b6518ef667 | ||
|
|
25c58e6209 | ||
|
|
97ee4b55c2 | ||
|
|
12eea72392 | ||
|
|
75c88bac50 | ||
|
|
ff1b6536bf | ||
|
|
06197f986d |
68
.dockerignore
Normal file
68
.dockerignore
Normal file
@@ -0,0 +1,68 @@
|
||||
# Git and GitHub
|
||||
.git
|
||||
.gitignore
|
||||
.github
|
||||
|
||||
# Node modules everywhere
|
||||
node_modules
|
||||
**/node_modules
|
||||
|
||||
# Backend - exclude everything except what's needed for build
|
||||
backend/tools
|
||||
backend/mysqldata
|
||||
backend/pgdata
|
||||
backend/mariadbdata
|
||||
backend/temp
|
||||
backend/images
|
||||
backend/bin
|
||||
backend/*.exe
|
||||
|
||||
# Scripts and data directories
|
||||
scripts
|
||||
postgresus-data
|
||||
|
||||
# IDE and editor files
|
||||
.idea
|
||||
.vscode
|
||||
.cursor
|
||||
**/*.swp
|
||||
**/*.swo
|
||||
|
||||
# Documentation and articles (not needed for build)
|
||||
articles
|
||||
docs
|
||||
pages
|
||||
|
||||
# Notifiers not needed in container
|
||||
notifiers
|
||||
|
||||
# Dist (will be built fresh)
|
||||
frontend/dist
|
||||
|
||||
# Environment files (handled separately)
|
||||
.env.local
|
||||
.env.development
|
||||
|
||||
# Logs and temp files
|
||||
**/*.log
|
||||
tmp
|
||||
temp
|
||||
|
||||
# OS files
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Helm charts and deployment configs
|
||||
deploy
|
||||
|
||||
# License and other root files
|
||||
LICENSE
|
||||
CITATION.cff
|
||||
*.md
|
||||
assets
|
||||
|
||||
# Python cache
|
||||
**/__pycache__
|
||||
|
||||
# Pre-commit config
|
||||
.pre-commit-config.yaml
|
||||
102
.github/CODE_OF_CONDUCT.md
vendored
Normal file
102
.github/CODE_OF_CONDUCT.md
vendored
Normal file
@@ -0,0 +1,102 @@
|
||||
# Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors and maintainers pledge to make participation in the Postgresus community a friendly and welcoming experience for everyone, regardless of background, experience level or personal circumstances.
|
||||
|
||||
We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive and healthy community.
|
||||
|
||||
## Our Standards
|
||||
|
||||
### Examples of behavior that contributes to a positive environment
|
||||
|
||||
- Using welcoming and inclusive language
|
||||
- Being respectful of differing viewpoints and experiences
|
||||
- Gracefully accepting constructive criticism
|
||||
- Focusing on what is best for the community
|
||||
- Showing empathy towards other community members
|
||||
- Helping newcomers get started with contributions
|
||||
- Providing clear and constructive feedback on pull requests
|
||||
- Celebrating successes and acknowledging contributions
|
||||
|
||||
### Examples of unacceptable behavior
|
||||
|
||||
- Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
- Publishing others' private information, such as physical or email addresses, without their explicit permission
|
||||
- Spam, self-promotion or off-topic content in project spaces
|
||||
- Other conduct which could reasonably be considered inappropriate in a professional setting
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, including:
|
||||
|
||||
- GitHub repositories (issues, pull requests, discussions, comments)
|
||||
- Telegram channels and direct messages related to Postgresus
|
||||
- Social media interactions when representing the project
|
||||
- Community forums and online discussions
|
||||
- Any other spaces where Postgresus community members interact
|
||||
|
||||
This Code of Conduct also applies when an individual is officially representing the community in public spaces, such as using an official email address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive or unacceptable behavior may be reported to the community leaders responsible for enforcement:
|
||||
|
||||
- **Email**: [info@postgresus.com](mailto:info@postgresus.com)
|
||||
- **Telegram**: [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
All complaints will be reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period of time. This includes avoiding interactions in community spaces as well as external channels like social media. Violating these terms may lead to a temporary or permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within the community.
|
||||
|
||||
## Contributing with Respect
|
||||
|
||||
When contributing to Postgresus, please:
|
||||
|
||||
- Be patient with maintainers and other contributors
|
||||
- Understand that everyone has different levels of experience
|
||||
- Ask questions in a respectful manner
|
||||
- Accept that your contribution may not be accepted, and be open to feedback
|
||||
- Follow the [contribution guidelines](https://postgresus.com/contribute)
|
||||
|
||||
For code contributions, remember to:
|
||||
|
||||
- Discuss significant changes before implementing them
|
||||
- Be open to code review feedback
|
||||
- Help review others' contributions when possible
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.1, available at [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html).
|
||||
|
||||
Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
|
||||
|
||||
For answers to common questions about this code of conduct, see the FAQ at [https://www.contributor-covenant.org/faq](https://www.contributor-covenant.org/faq).
|
||||
54
.github/SECURITY.md
vendored
Normal file
54
.github/SECURITY.md
vendored
Normal file
@@ -0,0 +1,54 @@
|
||||
# Security Policy
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
If you discover a security vulnerability in Postgresus, please report it responsibly. **Do not create a public GitHub issue for security vulnerabilities.**
|
||||
|
||||
### How to Report
|
||||
|
||||
1. **Email** (preferred): Send details to [info@postgresus.com](mailto:info@postgresus.com)
|
||||
2. **Telegram**: Contact [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
3. **GitHub Security Advisories**: Use the [private vulnerability reporting](https://github.com/RostislavDugin/postgresus/security/advisories/new) feature
|
||||
|
||||
### What to Include
|
||||
|
||||
- Description of the vulnerability
|
||||
- Steps to reproduce the issue
|
||||
- Potential impact and severity assessment
|
||||
- Any suggested fixes (optional)
|
||||
|
||||
## Supported Versions
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | --------- |
|
||||
| Latest | Yes |
|
||||
|
||||
We recommend always using the latest version of Postgresus. Security patches are applied to the most recent release.
|
||||
|
||||
### PostgreSQL Compatibility
|
||||
|
||||
Postgresus supports PostgreSQL versions 12, 13, 14, 15, 16, 17 and 18.
|
||||
|
||||
## Response Timeline
|
||||
|
||||
- **Acknowledgment**: Within 48-72 hours
|
||||
- **Initial Assessment**: Within 1 week
|
||||
- **Fix Timeline**: Depends on severity, but we aim to address critical issues as quickly as possible
|
||||
|
||||
We follow a coordinated disclosure policy. We ask that you give us reasonable time to address the vulnerability before any public disclosure.
|
||||
|
||||
## Security Features
|
||||
|
||||
Postgresus is designed with security in mind. For full details, see our [security documentation](https://postgresus.com/security).
|
||||
|
||||
Key features include:
|
||||
|
||||
- **AES-256-GCM Encryption**: Enterprise-grade encryption for backup files and sensitive data
|
||||
- **Read-Only Database Access**: Postgresus uses read-only access by default and warns if write permissions are detected
|
||||
- **Role-Based Access Control**: Assign viewer, member, admin or owner roles within workspaces
|
||||
- **Audit Logging**: Track all system activities and changes made by users
|
||||
- **Zero-Trust Storage**: Encrypted backups are safe even in shared cloud storage
|
||||
|
||||
## License
|
||||
|
||||
Postgresus is licensed under [Apache 2.0](../LICENSE).
|
||||
228
.github/workflows/ci-release.yml
vendored
228
.github/workflows/ci-release.yml
vendored
@@ -17,7 +17,7 @@ jobs:
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.3"
|
||||
go-version: "1.24.4"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
|
||||
- name: Install golangci-lint
|
||||
run: |
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin v1.60.3
|
||||
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/HEAD/install.sh | sh -s -- -b $(go env GOPATH)/bin v2.7.2
|
||||
echo "$(go env GOPATH)/bin" >> $GITHUB_PATH
|
||||
|
||||
- name: Install swag for swagger generation
|
||||
@@ -82,17 +82,59 @@ jobs:
|
||||
cd frontend
|
||||
npm run lint
|
||||
|
||||
test-frontend:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-frontend]
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Node.js
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: "20"
|
||||
cache: "npm"
|
||||
cache-dependency-path: frontend/package-lock.json
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
cd frontend
|
||||
npm ci
|
||||
|
||||
- name: Run frontend tests
|
||||
run: |
|
||||
cd frontend
|
||||
npm run test
|
||||
|
||||
test-backend:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint-backend]
|
||||
steps:
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Disk space before cleanup:"
|
||||
df -h
|
||||
# Remove unnecessary pre-installed software
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /usr/local/share/boost
|
||||
sudo rm -rf /usr/share/swift
|
||||
# Clean apt cache
|
||||
sudo apt-get clean
|
||||
# Clean docker images (if any pre-installed)
|
||||
docker system prune -af --volumes || true
|
||||
echo "Disk space after cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: "1.23.3"
|
||||
go-version: "1.24.4"
|
||||
|
||||
- name: Cache Go modules
|
||||
uses: actions/cache@v4
|
||||
@@ -141,9 +183,35 @@ jobs:
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# testing NAS
|
||||
TEST_NAS_PORT=7006
|
||||
# testing FTP
|
||||
TEST_FTP_PORT=7007
|
||||
# testing SFTP
|
||||
TEST_SFTP_PORT=7008
|
||||
# testing MySQL
|
||||
TEST_MYSQL_57_PORT=33057
|
||||
TEST_MYSQL_80_PORT=33080
|
||||
TEST_MYSQL_84_PORT=33084
|
||||
# testing MariaDB
|
||||
TEST_MARIADB_55_PORT=33055
|
||||
TEST_MARIADB_101_PORT=33101
|
||||
TEST_MARIADB_102_PORT=33102
|
||||
TEST_MARIADB_103_PORT=33103
|
||||
TEST_MARIADB_104_PORT=33104
|
||||
TEST_MARIADB_105_PORT=33105
|
||||
TEST_MARIADB_106_PORT=33106
|
||||
TEST_MARIADB_1011_PORT=33111
|
||||
TEST_MARIADB_114_PORT=33114
|
||||
TEST_MARIADB_118_PORT=33118
|
||||
TEST_MARIADB_120_PORT=33120
|
||||
# testing Telegram
|
||||
TEST_TELEGRAM_BOT_TOKEN=${{ secrets.TEST_TELEGRAM_BOT_TOKEN }}
|
||||
TEST_TELEGRAM_CHAT_ID=${{ secrets.TEST_TELEGRAM_CHAT_ID }}
|
||||
# supabase
|
||||
TEST_SUPABASE_HOST=${{ secrets.TEST_SUPABASE_HOST }}
|
||||
TEST_SUPABASE_PORT=${{ secrets.TEST_SUPABASE_PORT }}
|
||||
TEST_SUPABASE_USERNAME=${{ secrets.TEST_SUPABASE_USERNAME }}
|
||||
TEST_SUPABASE_PASSWORD=${{ secrets.TEST_SUPABASE_PASSWORD }}
|
||||
TEST_SUPABASE_DATABASE=${{ secrets.TEST_SUPABASE_DATABASE }}
|
||||
EOF
|
||||
|
||||
- name: Start test containers
|
||||
@@ -170,6 +238,44 @@ jobs:
|
||||
# Wait for Azurite
|
||||
timeout 60 bash -c 'until nc -z localhost 10000; do sleep 2; done'
|
||||
|
||||
# Wait for FTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7007; do sleep 2; done'
|
||||
|
||||
# Wait for SFTP
|
||||
timeout 60 bash -c 'until nc -z localhost 7008; do sleep 2; done'
|
||||
|
||||
# Wait for MySQL containers
|
||||
echo "Waiting for MySQL 5.7..."
|
||||
timeout 120 bash -c 'until docker exec test-mysql-57 mysqladmin ping -h localhost -u root -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MySQL 8.0..."
|
||||
timeout 120 bash -c 'until docker exec test-mysql-80 mysqladmin ping -h localhost -u root -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MySQL 8.4..."
|
||||
timeout 120 bash -c 'until docker exec test-mysql-84 mysqladmin ping -h localhost -u root -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
|
||||
# Wait for MariaDB containers
|
||||
echo "Waiting for MariaDB 5.5..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-55 mysqladmin ping -h localhost -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.1..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-101 mysqladmin ping -h localhost -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.2..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-102 mysqladmin ping -h localhost -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.3..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-103 mysqladmin ping -h localhost -prootpassword --silent 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.4..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-104 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.5..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-105 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.6..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-106 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 10.11..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-1011 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 11.4..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-114 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 11.8..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-118 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
echo "Waiting for MariaDB 12.0..."
|
||||
timeout 120 bash -c 'until docker exec test-mariadb-120 healthcheck.sh --connect --innodb_initialized 2>/dev/null; do sleep 2; done'
|
||||
|
||||
- name: Create data and temp directories
|
||||
run: |
|
||||
# Create directories that are used for backups and restore
|
||||
@@ -177,12 +283,77 @@ jobs:
|
||||
mkdir -p postgresus-data/backups
|
||||
mkdir -p postgresus-data/temp
|
||||
|
||||
- name: Install PostgreSQL client tools
|
||||
- name: Cache PostgreSQL client tools
|
||||
id: cache-postgres
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: /usr/lib/postgresql
|
||||
key: postgres-clients-12-18-v1
|
||||
|
||||
- name: Cache MySQL client tools
|
||||
id: cache-mysql
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mysql
|
||||
key: mysql-clients-57-80-84-v1
|
||||
|
||||
- name: Cache MariaDB client tools
|
||||
id: cache-mariadb
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: backend/tools/mariadb
|
||||
key: mariadb-clients-106-121-v1
|
||||
|
||||
- name: Install MySQL dependencies
|
||||
run: |
|
||||
sudo apt-get update -qq
|
||||
sudo apt-get install -y -qq libncurses6
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libncurses.so.6 /usr/lib/x86_64-linux-gnu/libncurses.so.5
|
||||
sudo ln -sf /usr/lib/x86_64-linux-gnu/libtinfo.so.6 /usr/lib/x86_64-linux-gnu/libtinfo.so.5
|
||||
|
||||
- name: Install PostgreSQL, MySQL and MariaDB client tools
|
||||
if: steps.cache-postgres.outputs.cache-hit != 'true' || steps.cache-mysql.outputs.cache-hit != 'true' || steps.cache-mariadb.outputs.cache-hit != 'true'
|
||||
run: |
|
||||
chmod +x backend/tools/download_linux.sh
|
||||
cd backend/tools
|
||||
./download_linux.sh
|
||||
|
||||
- name: Setup PostgreSQL symlinks (when using cache)
|
||||
if: steps.cache-postgres.outputs.cache-hit == 'true'
|
||||
run: |
|
||||
cd backend/tools
|
||||
mkdir -p postgresql
|
||||
for version in 12 13 14 15 16 17 18; do
|
||||
version_dir="postgresql/postgresql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
pg_bin_dir="/usr/lib/postgresql/$version/bin"
|
||||
if [ -d "$pg_bin_dir" ]; then
|
||||
ln -sf "$pg_bin_dir/pg_dump" "$version_dir/bin/pg_dump"
|
||||
ln -sf "$pg_bin_dir/pg_dumpall" "$version_dir/bin/pg_dumpall"
|
||||
ln -sf "$pg_bin_dir/psql" "$version_dir/bin/psql"
|
||||
ln -sf "$pg_bin_dir/pg_restore" "$version_dir/bin/pg_restore"
|
||||
ln -sf "$pg_bin_dir/createdb" "$version_dir/bin/createdb"
|
||||
ln -sf "$pg_bin_dir/dropdb" "$version_dir/bin/dropdb"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Verify MariaDB client tools exist
|
||||
run: |
|
||||
cd backend/tools
|
||||
echo "Checking MariaDB client tools..."
|
||||
if [ -f "mariadb/mariadb-10.6/bin/mariadb-dump" ]; then
|
||||
echo "MariaDB 10.6 client tools found"
|
||||
ls -la mariadb/mariadb-10.6/bin/
|
||||
else
|
||||
echo "MariaDB 10.6 client tools NOT found"
|
||||
fi
|
||||
if [ -f "mariadb/mariadb-12.1/bin/mariadb-dump" ]; then
|
||||
echo "MariaDB 12.1 client tools found"
|
||||
ls -la mariadb/mariadb-12.1/bin/
|
||||
else
|
||||
echo "MariaDB 12.1 client tools NOT found"
|
||||
fi
|
||||
|
||||
- name: Run database migrations
|
||||
run: |
|
||||
cd backend
|
||||
@@ -202,7 +373,7 @@ jobs:
|
||||
|
||||
determine-version:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, lint-frontend]
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
outputs:
|
||||
should_release: ${{ steps.version_bump.outputs.should_release }}
|
||||
@@ -295,7 +466,7 @@ jobs:
|
||||
|
||||
build-only:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [test-backend, lint-frontend]
|
||||
needs: [test-backend, test-frontend]
|
||||
if: ${{ github.ref == 'refs/heads/main' && contains(github.event.head_commit.message, '[skip-release]') }}
|
||||
steps:
|
||||
- name: Check out code
|
||||
@@ -455,6 +626,17 @@ jobs:
|
||||
echo EOF
|
||||
} >> $GITHUB_OUTPUT
|
||||
|
||||
- name: Update CITATION.cff version
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
sed -i "s/^version: .*/version: ${VERSION}/" CITATION.cff
|
||||
sed -i "s/^date-released: .*/date-released: \"$(date +%Y-%m-%d)\"/" CITATION.cff
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
git add CITATION.cff
|
||||
git commit -m "Update CITATION.cff to v${VERSION}" || true
|
||||
git push || true
|
||||
|
||||
- name: Create GitHub Release
|
||||
uses: actions/create-release@v1
|
||||
env:
|
||||
@@ -465,3 +647,37 @@ jobs:
|
||||
body: ${{ steps.changelog.outputs.changelog }}
|
||||
draft: false
|
||||
prerelease: false
|
||||
|
||||
publish-helm-chart:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [determine-version, build-and-push]
|
||||
if: ${{ needs.determine-version.outputs.should_release == 'true' }}
|
||||
permissions:
|
||||
contents: read
|
||||
packages: write
|
||||
steps:
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Helm
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.14.0
|
||||
|
||||
- name: Log in to GHCR
|
||||
run: echo "${{ secrets.GITHUB_TOKEN }}" | helm registry login ghcr.io -u ${{ github.actor }} --password-stdin
|
||||
|
||||
- name: Update Chart.yaml with release version
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
sed -i "s/^version: .*/version: ${VERSION}/" deploy/helm/Chart.yaml
|
||||
sed -i "s/^appVersion: .*/appVersion: \"v${VERSION}\"/" deploy/helm/Chart.yaml
|
||||
cat deploy/helm/Chart.yaml
|
||||
|
||||
- name: Package Helm chart
|
||||
run: helm package deploy/helm --destination .
|
||||
|
||||
- name: Push Helm chart to GHCR
|
||||
run: |
|
||||
VERSION="${{ needs.determine-version.outputs.new_version }}"
|
||||
helm push postgresus-${VERSION}.tgz oci://ghcr.io/rostislavdugin/charts
|
||||
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -5,4 +5,7 @@ pgdata/
|
||||
docker-compose.yml
|
||||
node_modules/
|
||||
.idea
|
||||
/articles
|
||||
/articles
|
||||
|
||||
.DS_Store
|
||||
/scripts
|
||||
33
CITATION.cff
Normal file
33
CITATION.cff
Normal file
@@ -0,0 +1,33 @@
|
||||
cff-version: 1.2.0
|
||||
title: Postgresus
|
||||
message: "If you use this software, please cite it as below."
|
||||
type: software
|
||||
authors:
|
||||
- family-names: Dugin
|
||||
given-names: Rostislav
|
||||
repository-code: https://github.com/RostislavDugin/postgresus
|
||||
url: https://postgresus.com
|
||||
abstract: "Free, open source and self-hosted solution for automated PostgreSQL backups with multiple storage options and notifications."
|
||||
keywords:
|
||||
- docker
|
||||
- kubernetes
|
||||
- golang
|
||||
- backups
|
||||
- postgres
|
||||
- devops
|
||||
- backup
|
||||
- database
|
||||
- tools
|
||||
- monitoring
|
||||
- ftp
|
||||
- postgresql
|
||||
- s3
|
||||
- psql
|
||||
- web-ui
|
||||
- self-hosted
|
||||
- pg
|
||||
- system-administration
|
||||
- database-backup
|
||||
license: Apache-2.0
|
||||
version: 2.12.0
|
||||
date-released: "2025-12-21"
|
||||
90
Dockerfile
90
Dockerfile
@@ -22,7 +22,7 @@ RUN npm run build
|
||||
|
||||
# ========= BUILD BACKEND =========
|
||||
# Backend build stage
|
||||
FROM --platform=$BUILDPLATFORM golang:1.23.3 AS backend-build
|
||||
FROM --platform=$BUILDPLATFORM golang:1.24.4 AS backend-build
|
||||
|
||||
# Make TARGET args available early so tools built here match the final image arch
|
||||
ARG TARGETOS
|
||||
@@ -77,16 +77,98 @@ ENV APP_VERSION=$APP_VERSION
|
||||
# Set production mode for Docker containers
|
||||
ENV ENV_MODE=production
|
||||
|
||||
# Install PostgreSQL server and client tools (versions 12-18)
|
||||
# Install PostgreSQL server and client tools (versions 12-18), MySQL client tools (5.7, 8.0, 8.4), MariaDB client tools, and rclone
|
||||
# Note: MySQL 5.7 is only available for x86_64, MySQL 8.0+ supports both x86_64 and ARM64
|
||||
# Note: MySQL binaries require libncurses5 for terminal handling
|
||||
# Note: MariaDB uses a single client version (12.1) that is backward compatible with all server versions
|
||||
ARG TARGETARCH
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
wget ca-certificates gnupg lsb-release sudo gosu && \
|
||||
wget ca-certificates gnupg lsb-release sudo gosu curl unzip xz-utils libncurses5 && \
|
||||
# Add PostgreSQL repository
|
||||
wget -qO- https://www.postgresql.org/media/keys/ACCC4CF8.asc | apt-key add - && \
|
||||
echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \
|
||||
> /etc/apt/sources.list.d/pgdg.list && \
|
||||
apt-get update && \
|
||||
# Install PostgreSQL
|
||||
apt-get install -y --no-install-recommends \
|
||||
postgresql-17 postgresql-18 postgresql-client-12 postgresql-client-13 postgresql-client-14 postgresql-client-15 \
|
||||
postgresql-client-16 postgresql-client-17 postgresql-client-18 && \
|
||||
postgresql-client-16 postgresql-client-17 postgresql-client-18 rclone && \
|
||||
# Create MySQL directories
|
||||
mkdir -p /usr/local/mysql-5.7/bin /usr/local/mysql-8.0/bin /usr/local/mysql-8.4/bin && \
|
||||
# Download and install MySQL client tools (architecture-aware)
|
||||
# MySQL 5.7: Only available for x86_64
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://dev.mysql.com/get/Downloads/MySQL-5.7/mysql-5.7.44-linux-glibc2.12-x86_64.tar.gz -O /tmp/mysql57.tar.gz && \
|
||||
tar -xzf /tmp/mysql57.tar.gz -C /tmp && \
|
||||
cp /tmp/mysql-5.7.*/bin/mysql /usr/local/mysql-5.7/bin/ && \
|
||||
cp /tmp/mysql-5.7.*/bin/mysqldump /usr/local/mysql-5.7/bin/ && \
|
||||
rm -rf /tmp/mysql-5.7.* /tmp/mysql57.tar.gz; \
|
||||
else \
|
||||
echo "MySQL 5.7 not available for $TARGETARCH, skipping..."; \
|
||||
fi && \
|
||||
# MySQL 8.0: Available for both x86_64 and ARM64
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://dev.mysql.com/get/Downloads/MySQL-8.0/mysql-8.0.40-linux-glibc2.17-x86_64-minimal.tar.xz -O /tmp/mysql80.tar.xz; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
wget -q https://dev.mysql.com/get/Downloads/MySQL-8.0/mysql-8.0.40-linux-glibc2.17-aarch64-minimal.tar.xz -O /tmp/mysql80.tar.xz; \
|
||||
fi && \
|
||||
tar -xJf /tmp/mysql80.tar.xz -C /tmp && \
|
||||
cp /tmp/mysql-8.0.*/bin/mysql /usr/local/mysql-8.0/bin/ && \
|
||||
cp /tmp/mysql-8.0.*/bin/mysqldump /usr/local/mysql-8.0/bin/ && \
|
||||
rm -rf /tmp/mysql-8.0.* /tmp/mysql80.tar.xz && \
|
||||
# MySQL 8.4: Available for both x86_64 and ARM64
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://dev.mysql.com/get/Downloads/MySQL-8.4/mysql-8.4.3-linux-glibc2.17-x86_64-minimal.tar.xz -O /tmp/mysql84.tar.xz; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
wget -q https://dev.mysql.com/get/Downloads/MySQL-8.4/mysql-8.4.3-linux-glibc2.17-aarch64-minimal.tar.xz -O /tmp/mysql84.tar.xz; \
|
||||
fi && \
|
||||
tar -xJf /tmp/mysql84.tar.xz -C /tmp && \
|
||||
cp /tmp/mysql-8.4.*/bin/mysql /usr/local/mysql-8.4/bin/ && \
|
||||
cp /tmp/mysql-8.4.*/bin/mysqldump /usr/local/mysql-8.4/bin/ && \
|
||||
rm -rf /tmp/mysql-8.4.* /tmp/mysql84.tar.xz && \
|
||||
# Make MySQL binaries executable (ignore errors for empty dirs on ARM64)
|
||||
chmod +x /usr/local/mysql-*/bin/* 2>/dev/null || true && \
|
||||
# Create MariaDB directories for both versions
|
||||
# MariaDB uses two client versions:
|
||||
# - 10.6 (legacy): For older servers (5.5, 10.1) that don't have generation_expression column
|
||||
# - 12.1 (modern): For newer servers (10.2+)
|
||||
mkdir -p /usr/local/mariadb-10.6/bin /usr/local/mariadb-12.1/bin && \
|
||||
# Download and install MariaDB 10.6 client tools (legacy - for older servers)
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://archive.mariadb.org/mariadb-10.6.21/bintar-linux-systemd-x86_64/mariadb-10.6.21-linux-systemd-x86_64.tar.gz -O /tmp/mariadb106.tar.gz && \
|
||||
tar -xzf /tmp/mariadb106.tar.gz -C /tmp && \
|
||||
cp /tmp/mariadb-10.6.*/bin/mariadb /usr/local/mariadb-10.6/bin/ && \
|
||||
cp /tmp/mariadb-10.6.*/bin/mariadb-dump /usr/local/mariadb-10.6/bin/ && \
|
||||
rm -rf /tmp/mariadb-10.6.* /tmp/mariadb106.tar.gz; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
# For ARM64, install MariaDB 10.6 client from official repository
|
||||
curl -fsSL https://mariadb.org/mariadb_release_signing_key.asc | gpg --dearmor -o /usr/share/keyrings/mariadb-keyring.gpg && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/mariadb-keyring.gpg] https://mirror.mariadb.org/repo/10.6/debian $(lsb_release -cs) main" > /etc/apt/sources.list.d/mariadb106.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends mariadb-client && \
|
||||
cp /usr/bin/mariadb /usr/local/mariadb-10.6/bin/mariadb && \
|
||||
cp /usr/bin/mariadb-dump /usr/local/mariadb-10.6/bin/mariadb-dump && \
|
||||
apt-get remove -y mariadb-client && \
|
||||
rm /etc/apt/sources.list.d/mariadb106.list; \
|
||||
fi && \
|
||||
# Download and install MariaDB 12.1 client tools (modern - for newer servers)
|
||||
if [ "$TARGETARCH" = "amd64" ]; then \
|
||||
wget -q https://archive.mariadb.org/mariadb-12.1.2/bintar-linux-systemd-x86_64/mariadb-12.1.2-linux-systemd-x86_64.tar.gz -O /tmp/mariadb121.tar.gz && \
|
||||
tar -xzf /tmp/mariadb121.tar.gz -C /tmp && \
|
||||
cp /tmp/mariadb-12.1.*/bin/mariadb /usr/local/mariadb-12.1/bin/ && \
|
||||
cp /tmp/mariadb-12.1.*/bin/mariadb-dump /usr/local/mariadb-12.1/bin/ && \
|
||||
rm -rf /tmp/mariadb-12.1.* /tmp/mariadb121.tar.gz; \
|
||||
elif [ "$TARGETARCH" = "arm64" ]; then \
|
||||
# For ARM64, install MariaDB 12.1 client from official repository
|
||||
echo "deb [signed-by=/usr/share/keyrings/mariadb-keyring.gpg] https://mirror.mariadb.org/repo/12.1/debian $(lsb_release -cs) main" > /etc/apt/sources.list.d/mariadb121.list && \
|
||||
apt-get update && \
|
||||
apt-get install -y --no-install-recommends mariadb-client && \
|
||||
cp /usr/bin/mariadb /usr/local/mariadb-12.1/bin/mariadb && \
|
||||
cp /usr/bin/mariadb-dump /usr/local/mariadb-12.1/bin/mariadb-dump; \
|
||||
fi && \
|
||||
# Make MariaDB binaries executable
|
||||
chmod +x /usr/local/mariadb-*/bin/* 2>/dev/null || true && \
|
||||
# Cleanup
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Create postgres user and set up directories
|
||||
|
||||
78
README.md
78
README.md
@@ -36,25 +36,25 @@
|
||||
|
||||
## ✨ Features
|
||||
|
||||
### 🔄 **Scheduled Backups**
|
||||
### 🔄 **Scheduled backups**
|
||||
|
||||
- **Flexible scheduling**: hourly, daily, weekly, monthly
|
||||
- **Flexible scheduling**: hourly, daily, weekly, monthly or cron
|
||||
- **Precise timing**: run backups at specific times (e.g., 4 AM during low traffic)
|
||||
- **Smart compression**: 4-8x space savings with balanced compression (~20% overhead)
|
||||
|
||||
### 🗄️ **Multiple Storage Destinations** <a href="https://postgresus.com/storages">(view supported)</a>
|
||||
### 🗄️ **Multiple storage destinations** <a href="https://postgresus.com/storages">(view supported)</a>
|
||||
|
||||
- **Local storage**: Keep backups on your VPS/server
|
||||
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox and more
|
||||
- **Cloud storage**: S3, Cloudflare R2, Google Drive, NAS, Dropbox, SFTP, Rclone and more
|
||||
- **Secure**: All data stays under your control
|
||||
|
||||
### 📱 **Smart Notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
|
||||
### 📱 **Smart notifications** <a href="https://postgresus.com/notifiers">(view supported)</a>
|
||||
|
||||
- **Multiple channels**: Email, Telegram, Slack, Discord, webhooks
|
||||
- **Real-time updates**: Success and failure notifications
|
||||
- **Team integration**: Perfect for DevOps workflows
|
||||
|
||||
### 🐘 **PostgreSQL Support**
|
||||
### 🐘 **PostgreSQL support**
|
||||
|
||||
- **Multiple versions**: PostgreSQL 12, 13, 14, 15, 16, 17 and 18
|
||||
- **SSL support**: Secure connections available
|
||||
@@ -67,7 +67,7 @@
|
||||
- **Encryption for secrets**: Any sensitive data is encrypted and never exposed, even in logs or error messages
|
||||
- **Read-only user**: Postgresus uses by default a read-only user for backups and never stores anything that can change your data
|
||||
|
||||
### 👥 **Suitable for Teams** <a href="https://postgresus.com/access-management">(docs)</a>
|
||||
### 👥 **Suitable for teams** <a href="https://postgresus.com/access-management">(docs)</a>
|
||||
|
||||
- **Workspaces**: Group databases, notifiers and storages for different projects or teams
|
||||
- **Access management**: Control who can view or manage specific databases with role-based permissions
|
||||
@@ -80,7 +80,16 @@
|
||||
- **Dark & light themes**: Choose the look that suits your workflow
|
||||
- **Mobile adaptive**: Check your backups from anywhere on any device
|
||||
|
||||
### 🐳 **Self-Hosted & Secure**
|
||||
### ☁️ **Works with self-hosted & cloud databases**
|
||||
|
||||
Postgresus works seamlessly with both self-hosted PostgreSQL and cloud-managed databases:
|
||||
|
||||
- **Cloud support**: AWS RDS, Google Cloud SQL, Azure Database for PostgreSQL
|
||||
- **Self-hosted**: Any PostgreSQL instance you manage yourself
|
||||
- **Why no PITR?**: Cloud providers already offer native PITR, and external PITR backups cannot be restored to managed cloud databases — making them impractical for cloud-hosted PostgreSQL
|
||||
- **Practical granularity**: Hourly and daily backups are sufficient for 99% of projects without the operational complexity of WAL archiving
|
||||
|
||||
### 🐳 **Self-hosted & secure**
|
||||
|
||||
- **Docker-based**: Easy deployment and management
|
||||
- **Privacy-first**: All your data stays on your infrastructure
|
||||
@@ -88,7 +97,7 @@
|
||||
|
||||
### 📦 Installation <a href="https://postgresus.com/installation">(docs)</a>
|
||||
|
||||
You have three ways to install Postgresus:
|
||||
You have several ways to install Postgresus:
|
||||
|
||||
- Script (recommended)
|
||||
- Simple Docker run
|
||||
@@ -102,11 +111,11 @@ You have three ways to install Postgresus:
|
||||
|
||||
You have three ways to install Postgresus: automated script (recommended), simple Docker run, or Docker Compose setup.
|
||||
|
||||
### Option 1: Automated Installation Script (Recommended, Linux only)
|
||||
### Option 1: Automated installation script (recommended, Linux only)
|
||||
|
||||
The installation script will:
|
||||
|
||||
- ✅ Install Docker with Docker Compose(if not already installed)
|
||||
- ✅ Install Docker with Docker Compose (if not already installed)
|
||||
- ✅ Set up Postgresus
|
||||
- ✅ Configure automatic startup on system reboot
|
||||
|
||||
@@ -116,7 +125,7 @@ sudo curl -sSL https://raw.githubusercontent.com/RostislavDugin/postgresus/refs/
|
||||
| sudo bash
|
||||
```
|
||||
|
||||
### Option 2: Simple Docker Run
|
||||
### Option 2: Simple Docker run
|
||||
|
||||
The easiest way to run Postgresus with embedded PostgreSQL:
|
||||
|
||||
@@ -135,7 +144,7 @@ This single command will:
|
||||
- ✅ Store all data in `./postgresus-data` directory
|
||||
- ✅ Automatically restart on system reboot
|
||||
|
||||
### Option 3: Docker Compose Setup
|
||||
### Option 3: Docker Compose setup
|
||||
|
||||
Create a `docker-compose.yml` file with the following configuration:
|
||||
|
||||
@@ -159,32 +168,43 @@ docker compose up -d
|
||||
|
||||
### Option 4: Kubernetes with Helm
|
||||
|
||||
For Kubernetes deployments, use the official Helm chart.
|
||||
For Kubernetes deployments, install directly from the OCI registry.
|
||||
|
||||
**Step 1:** Clone the repository:
|
||||
**With ClusterIP + port-forward (development/testing):**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/RostislavDugin/postgresus.git
|
||||
cd postgresus
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace
|
||||
```
|
||||
|
||||
**Step 2:** Install with Helm:
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace
|
||||
kubectl port-forward svc/postgresus-service 4005:4005 -n postgresus
|
||||
# Access at http://localhost:4005
|
||||
```
|
||||
|
||||
**Step 3:** Get the external IP:
|
||||
**With LoadBalancer (cloud environments):**
|
||||
|
||||
```bash
|
||||
kubectl get svc -n postgresus
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
--set service.type=LoadBalancer
|
||||
```
|
||||
|
||||
Access Postgresus at `http://<EXTERNAL-IP>` (port 80).
|
||||
```bash
|
||||
kubectl get svc postgresus-service -n postgresus
|
||||
# Access at http://<EXTERNAL-IP>:4005
|
||||
```
|
||||
|
||||
To customize the installation (e.g., storage size, NodePort instead of LoadBalancer), see the [Helm chart README](deploy/helm/README.md) for all configuration options.
|
||||
**With Ingress (domain-based access):**
|
||||
|
||||
Config uses by default LoadBalancer, but has predefined values for Ingress and HTTPRoute as well.
|
||||
```bash
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
--set ingress.enabled=true \
|
||||
--set ingress.hosts[0].host=backup.example.com
|
||||
```
|
||||
|
||||
For more options (NodePort, TLS, HTTPRoute for Gateway API), see the [Helm chart README](deploy/helm/README.md).
|
||||
|
||||
---
|
||||
|
||||
@@ -192,13 +212,13 @@ Config uses by default LoadBalancer, but has predefined values for Ingress and H
|
||||
|
||||
1. **Access the dashboard**: Navigate to `http://localhost:4005`
|
||||
2. **Add first DB for backup**: Click "New Database" and follow the setup wizard
|
||||
3. **Configure schedule**: Choose from hourly, daily, weekly or monthly intervals
|
||||
3. **Configure schedule**: Choose from hourly, daily, weekly, monthly or cron intervals
|
||||
4. **Set database connection**: Enter your PostgreSQL credentials and connection details
|
||||
5. **Choose storage**: Select where to store your backups (local, S3, Google Drive, etc.)
|
||||
6. **Add notifications** (optional): Configure email, Telegram, Slack, or webhook notifications
|
||||
7. **Save and start**: Postgresus will validate settings and begin the backup schedule
|
||||
|
||||
### 🔑 Resetting Password <a href="https://postgresus.com/password">(docs)</a>
|
||||
### 🔑 Resetting password <a href="https://postgresus.com/password">(docs)</a>
|
||||
|
||||
If you need to reset the password, you can use the built-in password reset command:
|
||||
|
||||
@@ -212,10 +232,10 @@ Replace `admin` with the actual email address of the user whose password you wan
|
||||
|
||||
## 📝 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details.
|
||||
This project is licensed under the Apache 2.0 License - see the [LICENSE](LICENSE) file for details
|
||||
|
||||
---
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
Contributions are welcome! Read <a href="https://postgresus.com/contributing">contributing guide</a> for more details, prioerities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
Contributions are welcome! Read <a href="https://postgresus.com/contribute">contributing guide</a> for more details, priorities and rules are specified there. If you want to contribute, but don't know what and how - message me on Telegram [@rostislav_dugin](https://t.me/rostislav_dugin)
|
||||
|
||||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 537 KiB After Width: | Height: | Size: 766 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 913 KiB After Width: | Height: | Size: 771 KiB |
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 34 KiB After Width: | Height: | Size: 13 KiB |
@@ -9,4 +9,4 @@ When applying changes, do not forget to refactor old code.
|
||||
You can shortify, make more readable, improve code quality, etc.
|
||||
Common logic can be extracted to functions, constants, files, etc.
|
||||
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder).
|
||||
After each large change with more than ~50-100 lines of code - always run `make lint` (from backend root folder) and, if you change frontend, run `npm run format` (from frontend root folder).
|
||||
|
||||
@@ -33,4 +33,30 @@ TEST_NAS_PORT=7006
|
||||
TEST_TELEGRAM_BOT_TOKEN=
|
||||
TEST_TELEGRAM_CHAT_ID=
|
||||
# testing Azure Blob Storage
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
TEST_AZURITE_BLOB_PORT=10000
|
||||
# supabase
|
||||
TEST_SUPABASE_HOST=
|
||||
TEST_SUPABASE_PORT=
|
||||
TEST_SUPABASE_USERNAME=
|
||||
TEST_SUPABASE_PASSWORD=
|
||||
TEST_SUPABASE_DATABASE=
|
||||
# FTP
|
||||
TEST_FTP_PORT=7007
|
||||
# SFTP
|
||||
TEST_SFTP_PORT=7008
|
||||
# MySQL Test Ports
|
||||
TEST_MYSQL_57_PORT=33057
|
||||
TEST_MYSQL_80_PORT=33080
|
||||
TEST_MYSQL_84_PORT=33084
|
||||
# testing MariaDB
|
||||
TEST_MARIADB_55_PORT=33055
|
||||
TEST_MARIADB_101_PORT=33101
|
||||
TEST_MARIADB_102_PORT=33102
|
||||
TEST_MARIADB_103_PORT=33103
|
||||
TEST_MARIADB_104_PORT=33104
|
||||
TEST_MARIADB_105_PORT=33105
|
||||
TEST_MARIADB_106_PORT=33106
|
||||
TEST_MARIADB_1011_PORT=33111
|
||||
TEST_MARIADB_114_PORT=33114
|
||||
TEST_MARIADB_118_PORT=33118
|
||||
TEST_MARIADB_120_PORT=33120
|
||||
2
backend/.gitignore
vendored
2
backend/.gitignore
vendored
@@ -3,6 +3,8 @@ main
|
||||
docker-compose.yml
|
||||
pgdata
|
||||
pgdata_test/
|
||||
mysqldata/
|
||||
mariadbdata/
|
||||
main.exe
|
||||
swagger/
|
||||
swagger/*
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
timeout: 1m
|
||||
timeout: 5m
|
||||
tests: false
|
||||
concurrency: 4
|
||||
|
||||
|
||||
@@ -31,14 +31,6 @@ services:
|
||||
container_name: test-minio
|
||||
command: server /data --console-address ":9001"
|
||||
|
||||
# Test Azurite container
|
||||
test-azurite:
|
||||
image: mcr.microsoft.com/azure-storage/azurite
|
||||
ports:
|
||||
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
|
||||
container_name: test-azurite
|
||||
command: azurite-blob --blobHost 0.0.0.0
|
||||
|
||||
# Test PostgreSQL containers
|
||||
test-postgres-12:
|
||||
image: postgres:12
|
||||
@@ -117,6 +109,14 @@ services:
|
||||
container_name: test-postgres-18
|
||||
shm_size: 1gb
|
||||
|
||||
# Test Azurite container
|
||||
test-azurite:
|
||||
image: mcr.microsoft.com/azure-storage/azurite
|
||||
ports:
|
||||
- "${TEST_AZURITE_BLOB_PORT:-10000}:10000"
|
||||
container_name: test-azurite
|
||||
command: azurite-blob --blobHost 0.0.0.0
|
||||
|
||||
# Test NAS server (Samba)
|
||||
test-nas:
|
||||
image: dperson/samba:latest
|
||||
@@ -132,3 +132,293 @@ services:
|
||||
-s "backups;/shared;yes;no;no;testuser"
|
||||
-p
|
||||
container_name: test-nas
|
||||
|
||||
# Test FTP server
|
||||
test-ftp:
|
||||
image: stilliard/pure-ftpd:latest
|
||||
ports:
|
||||
- "${TEST_FTP_PORT:-21}:21"
|
||||
- "30000-30009:30000-30009"
|
||||
environment:
|
||||
- PUBLICHOST=localhost
|
||||
- FTP_USER_NAME=testuser
|
||||
- FTP_USER_PASS=testpassword
|
||||
- FTP_USER_HOME=/home/ftpusers/testuser
|
||||
- FTP_PASSIVE_PORTS=30000:30009
|
||||
container_name: test-ftp
|
||||
|
||||
# Test SFTP server
|
||||
test-sftp:
|
||||
image: atmoz/sftp:latest
|
||||
ports:
|
||||
- "${TEST_SFTP_PORT:-7008}:22"
|
||||
command: testuser:testpassword:1001::upload
|
||||
container_name: test-sftp
|
||||
|
||||
# Test MySQL containers
|
||||
test-mysql-57:
|
||||
image: mysql:5.7
|
||||
ports:
|
||||
- "${TEST_MYSQL_57_PORT:-33057}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mysqldata/mysql-57:/var/lib/mysql
|
||||
container_name: test-mysql-57
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mysql-80:
|
||||
image: mysql:8.0
|
||||
ports:
|
||||
- "${TEST_MYSQL_80_PORT:-33080}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci --default-authentication-plugin=mysql_native_password
|
||||
volumes:
|
||||
- ./mysqldata/mysql-80:/var/lib/mysql
|
||||
container_name: test-mysql-80
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mysql-84:
|
||||
image: mysql:8.4
|
||||
ports:
|
||||
- "${TEST_MYSQL_84_PORT:-33084}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mysqldata/mysql-84:/var/lib/mysql
|
||||
container_name: test-mysql-84
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
# Test MariaDB containers
|
||||
test-mariadb-55:
|
||||
image: mariadb:5.5
|
||||
ports:
|
||||
- "${TEST_MARIADB_55_PORT:-33055}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8 --collation-server=utf8_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-55:/var/lib/mysql
|
||||
container_name: test-mariadb-55
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-101:
|
||||
image: mariadb:10.1
|
||||
ports:
|
||||
- "${TEST_MARIADB_101_PORT:-33101}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-101:/var/lib/mysql
|
||||
container_name: test-mariadb-101
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-102:
|
||||
image: mariadb:10.2
|
||||
ports:
|
||||
- "${TEST_MARIADB_102_PORT:-33102}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-102:/var/lib/mysql
|
||||
container_name: test-mariadb-102
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-103:
|
||||
image: mariadb:10.3
|
||||
ports:
|
||||
- "${TEST_MARIADB_103_PORT:-33103}:3306"
|
||||
environment:
|
||||
- MYSQL_ROOT_PASSWORD=rootpassword
|
||||
- MYSQL_DATABASE=testdb
|
||||
- MYSQL_USER=testuser
|
||||
- MYSQL_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-103:/var/lib/mysql
|
||||
container_name: test-mariadb-103
|
||||
healthcheck:
|
||||
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-prootpassword"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-104:
|
||||
image: mariadb:10.4
|
||||
ports:
|
||||
- "${TEST_MARIADB_104_PORT:-33104}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-104:/var/lib/mysql
|
||||
container_name: test-mariadb-104
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-105:
|
||||
image: mariadb:10.5
|
||||
ports:
|
||||
- "${TEST_MARIADB_105_PORT:-33105}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-105:/var/lib/mysql
|
||||
container_name: test-mariadb-105
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-106:
|
||||
image: mariadb:10.6
|
||||
ports:
|
||||
- "${TEST_MARIADB_106_PORT:-33106}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-106:/var/lib/mysql
|
||||
container_name: test-mariadb-106
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-1011:
|
||||
image: mariadb:10.11
|
||||
ports:
|
||||
- "${TEST_MARIADB_1011_PORT:-33111}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-1011:/var/lib/mysql
|
||||
container_name: test-mariadb-1011
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-114:
|
||||
image: mariadb:11.4
|
||||
ports:
|
||||
- "${TEST_MARIADB_114_PORT:-33114}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-114:/var/lib/mysql
|
||||
container_name: test-mariadb-114
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-118:
|
||||
image: mariadb:11.8
|
||||
ports:
|
||||
- "${TEST_MARIADB_118_PORT:-33118}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-118:/var/lib/mysql
|
||||
container_name: test-mariadb-118
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
test-mariadb-120:
|
||||
image: mariadb:12.0
|
||||
ports:
|
||||
- "${TEST_MARIADB_120_PORT:-33120}:3306"
|
||||
environment:
|
||||
- MARIADB_ROOT_PASSWORD=rootpassword
|
||||
- MARIADB_DATABASE=testdb
|
||||
- MARIADB_USER=testuser
|
||||
- MARIADB_PASSWORD=testpassword
|
||||
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
|
||||
volumes:
|
||||
- ./mariadbdata/mariadb-120:/var/lib/mysql
|
||||
container_name: test-mariadb-120
|
||||
healthcheck:
|
||||
test: ["CMD", "healthcheck.sh", "--connect", "--innodb_initialized"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 10
|
||||
|
||||
241
backend/go.mod
241
backend/go.mod
@@ -1,6 +1,6 @@
|
||||
module postgresus-backend
|
||||
|
||||
go 1.23.3
|
||||
go 1.24.4
|
||||
|
||||
require (
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0
|
||||
@@ -12,35 +12,197 @@ require (
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/ilyakaznacheev/cleanenv v1.5.0
|
||||
github.com/jackc/pgx/v5 v5.7.5
|
||||
github.com/jlaffaye/ftp v0.2.1-0.20240918233326-1b970516f5d3
|
||||
github.com/jmoiron/sqlx v1.4.0
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/minio/minio-go/v7 v7.0.92
|
||||
github.com/shirou/gopsutil/v4 v4.25.5
|
||||
github.com/minio/minio-go/v7 v7.0.97
|
||||
github.com/pkg/sftp v1.13.10
|
||||
github.com/rclone/rclone v1.72.1
|
||||
github.com/robfig/cron/v3 v3.0.1
|
||||
github.com/shirou/gopsutil/v4 v4.25.10
|
||||
github.com/stretchr/testify v1.11.1
|
||||
github.com/swaggo/files v1.0.1
|
||||
github.com/swaggo/gin-swagger v1.6.0
|
||||
github.com/swaggo/swag v1.16.4
|
||||
golang.org/x/crypto v0.41.0
|
||||
golang.org/x/time v0.12.0
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/time v0.14.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.26.1
|
||||
)
|
||||
|
||||
require github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
require (
|
||||
filippo.io/edwards25519 v1.1.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.0 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 // indirect
|
||||
github.com/Azure/azure-sdk-for-go/sdk/storage/azfile v1.5.3 // indirect
|
||||
github.com/Azure/go-ntlmssp v0.0.2-0.20251110135918-10b7b7e7cd26 // indirect
|
||||
github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect
|
||||
github.com/Files-com/files-sdk-go/v3 v3.2.264 // indirect
|
||||
github.com/IBM/go-sdk-core/v5 v5.21.0 // indirect
|
||||
github.com/Max-Sum/base32768 v0.0.0-20230304063302-18e6ce5945fd // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/bcrypt v0.0.0-20211005172633-e235017c1baf // indirect
|
||||
github.com/ProtonMail/gluon v0.17.1-0.20230724134000-308be39be96e // indirect
|
||||
github.com/ProtonMail/go-crypto v1.3.0 // indirect
|
||||
github.com/ProtonMail/go-mime v0.0.0-20230322103455-7d82a3887f2f // indirect
|
||||
github.com/ProtonMail/go-srp v0.0.7 // indirect
|
||||
github.com/ProtonMail/gopenpgp/v2 v2.9.0 // indirect
|
||||
github.com/PuerkitoBio/goquery v1.10.3 // indirect
|
||||
github.com/a1ex3/zstd-seekable-format-go/pkg v0.10.0 // indirect
|
||||
github.com/abbot/go-http-auth v0.4.0 // indirect
|
||||
github.com/anchore/go-lzo v0.1.0 // indirect
|
||||
github.com/andybalholm/cascadia v1.3.3 // indirect
|
||||
github.com/appscode/go-querystring v0.0.0-20170504095604-0126cfb3f1dc // indirect
|
||||
github.com/aws/aws-sdk-go-v2 v1.39.6 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/config v1.31.17 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.21 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.20.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.4 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.13 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.90.0 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 // indirect
|
||||
github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 // indirect
|
||||
github.com/aws/smithy-go v1.23.2 // indirect
|
||||
github.com/bahlo/generic-list-go v0.2.0 // indirect
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/boombuler/barcode v1.1.0 // indirect
|
||||
github.com/bradenaw/juniper v0.15.3 // indirect
|
||||
github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 // indirect
|
||||
github.com/buengese/sgzip v0.1.1 // indirect
|
||||
github.com/buger/jsonparser v1.1.1 // indirect
|
||||
github.com/calebcase/tmpfile v1.0.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/chilts/sid v0.0.0-20190607042430-660e94789ec9 // indirect
|
||||
github.com/clipperhouse/stringish v0.1.1 // indirect
|
||||
github.com/clipperhouse/uax29/v2 v2.3.0 // indirect
|
||||
github.com/cloudflare/circl v1.6.1 // indirect
|
||||
github.com/cloudinary/cloudinary-go/v2 v2.13.0 // indirect
|
||||
github.com/cloudsoda/go-smb2 v0.0.0-20250228001242-d4c70e6251cc // indirect
|
||||
github.com/cloudsoda/sddl v0.0.0-20250224235906-926454e91efc // indirect
|
||||
github.com/colinmarc/hdfs/v2 v2.4.0 // indirect
|
||||
github.com/coreos/go-semver v0.3.1 // indirect
|
||||
github.com/coreos/go-systemd/v22 v22.6.0 // indirect
|
||||
github.com/creasty/defaults v1.8.0 // indirect
|
||||
github.com/cronokirby/saferith v0.33.0 // indirect
|
||||
github.com/diskfs/go-diskfs v1.7.0 // indirect
|
||||
github.com/dropbox/dropbox-sdk-go-unofficial/v6 v6.0.5 // indirect
|
||||
github.com/emersion/go-message v0.18.2 // indirect
|
||||
github.com/emersion/go-vcard v0.0.0-20241024213814-c9703dde27ff // indirect
|
||||
github.com/flynn/noise v1.1.0 // indirect
|
||||
github.com/go-chi/chi/v5 v5.2.3 // indirect
|
||||
github.com/go-darwin/apfs v0.0.0-20211011131704-f84b94dbf348 // indirect
|
||||
github.com/go-git/go-billy/v5 v5.6.2 // indirect
|
||||
github.com/go-openapi/errors v0.22.4 // indirect
|
||||
github.com/go-openapi/strfmt v0.25.0 // indirect
|
||||
github.com/go-resty/resty/v2 v2.16.5 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
|
||||
github.com/gofrs/flock v0.13.0 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/gorilla/schema v1.4.1 // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-cleanhttp v0.5.2 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/hashicorp/go-retryablehttp v0.7.8 // indirect
|
||||
github.com/hashicorp/go-uuid v1.0.3 // indirect
|
||||
github.com/henrybear327/Proton-API-Bridge v1.0.0 // indirect
|
||||
github.com/henrybear327/go-proton-api v1.0.0 // indirect
|
||||
github.com/jcmturner/aescts/v2 v2.0.0 // indirect
|
||||
github.com/jcmturner/dnsutils/v2 v2.0.0 // indirect
|
||||
github.com/jcmturner/gofork v1.7.6 // indirect
|
||||
github.com/jcmturner/goidentity/v6 v6.0.1 // indirect
|
||||
github.com/jcmturner/gokrb5/v8 v8.4.4 // indirect
|
||||
github.com/jcmturner/rpc/v2 v2.0.3 // indirect
|
||||
github.com/jtolio/noiseconn v0.0.0-20231127013910-f6d9ecbf1de7 // indirect
|
||||
github.com/jzelinskie/whirlpool v0.0.0-20201016144138-0675e54bb004 // indirect
|
||||
github.com/klauspost/crc32 v1.3.0 // indirect
|
||||
github.com/koofr/go-httpclient v0.0.0-20240520111329-e20f8f203988 // indirect
|
||||
github.com/koofr/go-koofrclient v0.0.0-20221207135200-cbd7fc9ad6a6 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/kylelemons/godebug v1.1.0 // indirect
|
||||
github.com/lanrat/extsort v1.4.2 // indirect
|
||||
github.com/lpar/date v1.0.0 // indirect
|
||||
github.com/lufia/plan9stats v0.0.0-20251013123823-9fd1530e3ec3 // indirect
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.19 // indirect
|
||||
github.com/mitchellh/go-homedir v1.1.0 // indirect
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/ncw/swift/v2 v2.0.5 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/oracle/oci-go-sdk/v65 v65.104.0 // indirect
|
||||
github.com/panjf2000/ants/v2 v2.11.3 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
|
||||
github.com/pengsrc/go-shared v0.2.1-0.20190131101655-1999055a4a14 // indirect
|
||||
github.com/peterh/liner v1.2.2 // indirect
|
||||
github.com/pierrec/lz4/v4 v4.1.22 // indirect
|
||||
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pkg/xattr v0.4.12 // indirect
|
||||
github.com/pquerna/otp v1.5.0 // indirect
|
||||
github.com/prometheus/client_golang v1.23.2 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.67.2 // indirect
|
||||
github.com/prometheus/procfs v0.19.2 // indirect
|
||||
github.com/putdotio/go-putio/putio v0.0.0-20200123120452-16d982cac2b8 // indirect
|
||||
github.com/relvacode/iso8601 v1.7.0 // indirect
|
||||
github.com/rfjakob/eme v1.1.2 // indirect
|
||||
github.com/sabhiram/go-gitignore v0.0.0-20210923224102-525f6e181f06 // indirect
|
||||
github.com/samber/lo v1.52.0 // indirect
|
||||
github.com/sirupsen/logrus v1.9.4-0.20230606125235-dd1b4c2e81af // indirect
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 // indirect
|
||||
github.com/sony/gobreaker v1.0.0 // indirect
|
||||
github.com/spacemonkeygo/monkit/v3 v3.0.25-0.20251022131615-eb24eb109368 // indirect
|
||||
github.com/spf13/pflag v1.0.10 // indirect
|
||||
github.com/t3rm1n4l/go-mega v0.0.0-20251031123324-a804aaa87491 // indirect
|
||||
github.com/tklauser/go-sysconf v0.3.15 // indirect
|
||||
github.com/tklauser/numcpus v0.10.0 // indirect
|
||||
github.com/ulikunitz/xz v0.5.15 // indirect
|
||||
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 // indirect
|
||||
github.com/yunify/qingstor-sdk-go/v3 v3.2.0 // indirect
|
||||
github.com/zeebo/blake3 v0.2.4 // indirect
|
||||
github.com/zeebo/errs v1.4.0 // indirect
|
||||
github.com/zeebo/xxh3 v1.0.2 // indirect
|
||||
go.etcd.io/bbolt v1.4.3 // indirect
|
||||
go.mongodb.org/mongo-driver v1.17.6 // indirect
|
||||
go.yaml.in/yaml/v2 v2.4.3 // indirect
|
||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
||||
golang.org/x/term v0.38.0 // indirect
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect
|
||||
gopkg.in/validator.v2 v2.0.1 // indirect
|
||||
moul.io/http2curl/v2 v2.3.0 // indirect
|
||||
sigs.k8s.io/yaml v1.6.0 // indirect
|
||||
storj.io/common v0.0.0-20251107171817-6221ae45072c // indirect
|
||||
storj.io/drpc v0.0.35-0.20250513201419-f7819ea69b55 // indirect
|
||||
storj.io/eventkit v0.0.0-20250410172343-61f26d3de156 // indirect
|
||||
storj.io/infectious v0.0.2 // indirect
|
||||
storj.io/picobuf v0.0.4 // indirect
|
||||
storj.io/uplink v1.13.1 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/auth v0.16.2 // indirect
|
||||
cloud.google.com/go/auth v0.17.0 // indirect
|
||||
cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.7.0 // indirect
|
||||
github.com/geoffgarside/ber v1.1.0 // indirect
|
||||
cloud.google.com/go/compute/metadata v0.9.0 // indirect
|
||||
github.com/geoffgarside/ber v1.2.0 // indirect
|
||||
github.com/google/s2a-go v0.1.9 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.6 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.14.2 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.15.0 // indirect
|
||||
github.com/hirochachacha/go-smb2 v1.1.0
|
||||
google.golang.org/genproto/googleapis/api v0.0.0-20250528174236-200df99c418a // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250603155806-513f23925822 // indirect
|
||||
google.golang.org/grpc v1.73.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect
|
||||
google.golang.org/grpc v1.76.0 // indirect
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -51,11 +213,11 @@ require (
|
||||
github.com/bytedance/sonic v1.13.2 // indirect
|
||||
github.com/bytedance/sonic/loader v0.2.4 // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/ebitengine/purego v0.8.4 // indirect
|
||||
github.com/ebitengine/purego v0.9.1 // indirect
|
||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.11 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-ini/ini v1.67.0 // indirect
|
||||
github.com/go-logr/logr v1.4.3 // indirect
|
||||
@@ -67,8 +229,8 @@ require (
|
||||
github.com/go-openapi/swag v0.19.15 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.26.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.2 // indirect
|
||||
github.com/go-playground/validator/v10 v10.28.0 // indirect
|
||||
github.com/go-sql-driver/mysql v1.9.2
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
|
||||
@@ -77,40 +239,39 @@ require (
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.18.0 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
|
||||
github.com/klauspost/compress v1.18.1
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.6 // indirect
|
||||
github.com/mailru/easyjson v0.9.1 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/minio/crc64nvme v1.0.1 // indirect
|
||||
github.com/minio/crc64nvme v1.1.1 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/philhofer/fwd v1.1.3-0.20240916144458-20a13a1f6b7c // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/philhofer/fwd v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
|
||||
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
|
||||
github.com/rogpeppe/go-internal v1.14.1 // indirect
|
||||
github.com/rs/xid v1.6.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tinylib/msgp v1.3.0 // indirect
|
||||
github.com/tinylib/msgp v1.5.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
github.com/yusufpapurcu/wmi v1.2.4 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.1.0 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect
|
||||
go.opentelemetry.io/otel v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.36.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.36.0 // indirect
|
||||
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
|
||||
go.opentelemetry.io/otel v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/metric v1.38.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.38.0 // indirect
|
||||
golang.org/x/arch v0.17.0 // indirect
|
||||
golang.org/x/net v0.43.0 // indirect
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0 // indirect
|
||||
golang.org/x/sys v0.35.0 // indirect
|
||||
golang.org/x/text v0.28.0 // indirect
|
||||
golang.org/x/tools v0.35.0 // indirect
|
||||
google.golang.org/api v0.239.0
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
golang.org/x/net v0.47.0 // indirect
|
||||
golang.org/x/oauth2 v0.33.0
|
||||
golang.org/x/sync v0.19.0 // indirect
|
||||
golang.org/x/sys v0.39.0 // indirect
|
||||
golang.org/x/text v0.32.0 // indirect
|
||||
golang.org/x/tools v0.39.0 // indirect
|
||||
google.golang.org/api v0.255.0
|
||||
google.golang.org/protobuf v1.36.10 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
olympos.io/encoding/edn v0.0.0-20201019073823-d3554ca0b0a3 // indirect
|
||||
|
||||
951
backend/go.sum
951
backend/go.sum
File diff suppressed because it is too large
Load Diff
@@ -25,6 +25,8 @@ type EnvVariables struct {
|
||||
DatabaseDsn string `env:"DATABASE_DSN" required:"true"`
|
||||
EnvMode env_utils.EnvMode `env:"ENV_MODE" required:"true"`
|
||||
PostgresesInstallDir string `env:"POSTGRES_INSTALL_DIR"`
|
||||
MysqlInstallDir string `env:"MYSQL_INSTALL_DIR"`
|
||||
MariadbInstallDir string `env:"MARIADB_INSTALL_DIR"`
|
||||
|
||||
DataFolder string
|
||||
TempFolder string
|
||||
@@ -47,7 +49,25 @@ type EnvVariables struct {
|
||||
|
||||
TestAzuriteBlobPort string `env:"TEST_AZURITE_BLOB_PORT"`
|
||||
|
||||
TestNASPort string `env:"TEST_NAS_PORT"`
|
||||
TestNASPort string `env:"TEST_NAS_PORT"`
|
||||
TestFTPPort string `env:"TEST_FTP_PORT"`
|
||||
TestSFTPPort string `env:"TEST_SFTP_PORT"`
|
||||
|
||||
TestMysql57Port string `env:"TEST_MYSQL_57_PORT"`
|
||||
TestMysql80Port string `env:"TEST_MYSQL_80_PORT"`
|
||||
TestMysql84Port string `env:"TEST_MYSQL_84_PORT"`
|
||||
|
||||
TestMariadb55Port string `env:"TEST_MARIADB_55_PORT"`
|
||||
TestMariadb101Port string `env:"TEST_MARIADB_101_PORT"`
|
||||
TestMariadb102Port string `env:"TEST_MARIADB_102_PORT"`
|
||||
TestMariadb103Port string `env:"TEST_MARIADB_103_PORT"`
|
||||
TestMariadb104Port string `env:"TEST_MARIADB_104_PORT"`
|
||||
TestMariadb105Port string `env:"TEST_MARIADB_105_PORT"`
|
||||
TestMariadb106Port string `env:"TEST_MARIADB_106_PORT"`
|
||||
TestMariadb1011Port string `env:"TEST_MARIADB_1011_PORT"`
|
||||
TestMariadb114Port string `env:"TEST_MARIADB_114_PORT"`
|
||||
TestMariadb118Port string `env:"TEST_MARIADB_118_PORT"`
|
||||
TestMariadb120Port string `env:"TEST_MARIADB_120_PORT"`
|
||||
|
||||
// oauth
|
||||
GitHubClientID string `env:"GITHUB_CLIENT_ID"`
|
||||
@@ -58,6 +78,13 @@ type EnvVariables struct {
|
||||
// testing Telegram
|
||||
TestTelegramBotToken string `env:"TEST_TELEGRAM_BOT_TOKEN"`
|
||||
TestTelegramChatID string `env:"TEST_TELEGRAM_CHAT_ID"`
|
||||
|
||||
// testing Supabase
|
||||
TestSupabaseHost string `env:"TEST_SUPABASE_HOST"`
|
||||
TestSupabasePort string `env:"TEST_SUPABASE_PORT"`
|
||||
TestSupabaseUsername string `env:"TEST_SUPABASE_USERNAME"`
|
||||
TestSupabasePassword string `env:"TEST_SUPABASE_PASSWORD"`
|
||||
TestSupabaseDatabase string `env:"TEST_SUPABASE_DATABASE"`
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -143,6 +170,12 @@ func loadEnvVariables() {
|
||||
env.PostgresesInstallDir = filepath.Join(backendRoot, "tools", "postgresql")
|
||||
tools.VerifyPostgresesInstallation(log, env.EnvMode, env.PostgresesInstallDir)
|
||||
|
||||
env.MysqlInstallDir = filepath.Join(backendRoot, "tools", "mysql")
|
||||
tools.VerifyMysqlInstallation(log, env.EnvMode, env.MysqlInstallDir)
|
||||
|
||||
env.MariadbInstallDir = filepath.Join(backendRoot, "tools", "mariadb")
|
||||
tools.VerifyMariadbInstallation(log, env.EnvMode, env.MariadbInstallDir)
|
||||
|
||||
// Store the data and temp folders one level below the root
|
||||
// (projectRoot/postgresus-data -> /postgresus-data)
|
||||
env.DataFolder = filepath.Join(filepath.Dir(backendRoot), "postgresus-data", "backups")
|
||||
|
||||
@@ -2,20 +2,21 @@ package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type BackupContextManager struct {
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
mu sync.RWMutex
|
||||
cancelFuncs map[uuid.UUID]context.CancelFunc
|
||||
cancelledBackups map[uuid.UUID]bool
|
||||
}
|
||||
|
||||
func NewBackupContextManager() *BackupContextManager {
|
||||
return &BackupContextManager{
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelFuncs: make(map[uuid.UUID]context.CancelFunc),
|
||||
cancelledBackups: make(map[uuid.UUID]bool),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,25 +24,37 @@ func (m *BackupContextManager) RegisterBackup(backupID uuid.UUID, cancelFunc con
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cancelFuncs[backupID] = cancelFunc
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) CancelBackup(backupID uuid.UUID) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if !exists {
|
||||
return errors.New("backup is not in progress or already completed")
|
||||
if m.cancelledBackups[backupID] {
|
||||
return nil
|
||||
}
|
||||
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
cancelFunc, exists := m.cancelFuncs[backupID]
|
||||
if exists {
|
||||
cancelFunc()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
}
|
||||
|
||||
m.cancelledBackups[backupID] = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) IsCancelled(backupID uuid.UUID) bool {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.cancelledBackups[backupID]
|
||||
}
|
||||
|
||||
func (m *BackupContextManager) UnregisterBackup(backupID uuid.UUID) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.cancelFuncs, backupID)
|
||||
delete(m.cancelledBackups, backupID)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -181,7 +182,7 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
fileReader, err := c.backupService.GetBackupFile(user, id)
|
||||
fileReader, dbType, err := c.backupService.GetBackupFile(user, id)
|
||||
if err != nil {
|
||||
ctx.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -192,10 +193,15 @@ func (c *BackupController) GetFile(ctx *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
extension := ".dump.zst"
|
||||
if dbType == databases.DatabaseTypeMysql {
|
||||
extension = ".sql.zst"
|
||||
}
|
||||
|
||||
ctx.Header("Content-Type", "application/octet-stream")
|
||||
ctx.Header(
|
||||
"Content-Disposition",
|
||||
fmt.Sprintf("attachment; filename=\"backup_%s.dump\"", id.String()),
|
||||
fmt.Sprintf("attachment; filename=\"backup_%s%s\"", id.String(), extension),
|
||||
)
|
||||
|
||||
_, err = io.Copy(ctx.Writer, fileReader)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package backups
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -701,7 +702,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
if err := storages[0].SaveFile(context.Background(), encryption.GetFieldEncryptor(), logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ const (
|
||||
NonceLen = 12
|
||||
ReservedLen = 12
|
||||
HeaderLen = MagicBytesLen + SaltLen + NonceLen + ReservedLen
|
||||
ChunkSize = 32 * 1024
|
||||
ChunkSize = 1 * 1024 * 1024
|
||||
PBKDF2Iterations = 100000
|
||||
)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package backups
|
||||
import (
|
||||
"context"
|
||||
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
usecases_common "postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
@@ -27,10 +27,8 @@ type CreateBackupUsecase interface {
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) (*usecases_postgresql.BackupMetadata, error)
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*usecases_common.BackupMetadata, error)
|
||||
}
|
||||
|
||||
type BackupRemoveListener interface {
|
||||
|
||||
@@ -275,7 +275,12 @@ func (s *BackupService) MakeBackup(databaseID uuid.UUID, isLastTry bool) {
|
||||
errMsg := err.Error()
|
||||
|
||||
// Check if backup was cancelled (not due to shutdown)
|
||||
if strings.Contains(errMsg, "backup cancelled") && !strings.Contains(errMsg, "shutdown") {
|
||||
isCancelled := strings.Contains(errMsg, "backup cancelled") ||
|
||||
strings.Contains(errMsg, "context canceled") ||
|
||||
errors.Is(err, context.Canceled)
|
||||
isShutdown := strings.Contains(errMsg, "shutdown")
|
||||
|
||||
if isCancelled && !isShutdown {
|
||||
backup.Status = BackupStatusCanceled
|
||||
backup.BackupDurationMs = time.Since(start).Milliseconds()
|
||||
backup.BackupSizeMb = 0
|
||||
@@ -497,19 +502,19 @@ func (s *BackupService) CancelBackup(
|
||||
func (s *BackupService) GetBackupFile(
|
||||
user *users_models.User,
|
||||
backupID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
) (io.ReadCloser, databases.DatabaseType, error) {
|
||||
backup, err := s.backupRepository.FindByID(backupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
database, err := s.databaseService.GetDatabaseByID(backup.DatabaseID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if database.WorkspaceID == nil {
|
||||
return nil, errors.New("cannot download backup for database without workspace")
|
||||
return nil, "", errors.New("cannot download backup for database without workspace")
|
||||
}
|
||||
|
||||
canAccess, _, err := s.workspaceService.CanUserAccessWorkspace(
|
||||
@@ -517,10 +522,10 @@ func (s *BackupService) GetBackupFile(
|
||||
user,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
if !canAccess {
|
||||
return nil, errors.New("insufficient permissions to download backup for this database")
|
||||
return nil, "", errors.New("insufficient permissions to download backup for this database")
|
||||
}
|
||||
|
||||
s.auditLogService.WriteAuditLog(
|
||||
@@ -533,7 +538,12 @@ func (s *BackupService) GetBackupFile(
|
||||
database.WorkspaceID,
|
||||
)
|
||||
|
||||
return s.getBackupReader(backupID)
|
||||
reader, err := s.getBackupReader(backupID)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
return reader, database.Type, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteBackup(backup *Backup) error {
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
"postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
@@ -178,16 +178,13 @@ func (uc *CreateFailedBackupUsecase) Execute(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return nil, errors.New("backup failed")
|
||||
}
|
||||
|
||||
type CreateSuccessBackupUsecase struct {
|
||||
}
|
||||
type CreateSuccessBackupUsecase struct{}
|
||||
|
||||
func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
@@ -195,12 +192,10 @@ func (uc *CreateSuccessBackupUsecase) Execute(
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
backupProgressListener(10) // Assume we completed 10MB
|
||||
return &usecases_postgresql.BackupMetadata{
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*common.BackupMetadata, error) {
|
||||
backupProgressListener(10)
|
||||
return &common.BackupMetadata{
|
||||
EncryptionSalt: nil,
|
||||
EncryptionIV: nil,
|
||||
Encryption: backups_config.BackupEncryptionNone,
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
package usecases_postgresql
|
||||
package common
|
||||
|
||||
import backups_config "postgresus-backend/internal/features/backups/config"
|
||||
|
||||
type EncryptionMetadata struct {
|
||||
Salt string
|
||||
IV string
|
||||
Encryption backups_config.BackupEncryption
|
||||
}
|
||||
|
||||
type BackupMetadata struct {
|
||||
EncryptionSalt *string
|
||||
EncryptionIV *string
|
||||
@@ -0,0 +1,22 @@
|
||||
package common
|
||||
|
||||
import "io"
|
||||
|
||||
type CountingWriter struct {
|
||||
Writer io.Writer
|
||||
BytesWritten int64
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = cw.Writer.Write(p)
|
||||
cw.BytesWritten += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) GetBytesWritten() int64 {
|
||||
return cw.BytesWritten
|
||||
}
|
||||
|
||||
func NewCountingWriter(writer io.Writer) *CountingWriter {
|
||||
return &CountingWriter{Writer: writer}
|
||||
}
|
||||
@@ -3,6 +3,10 @@ package usecases
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
usecases_common "postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
usecases_mariadb "postgresus-backend/internal/features/backups/backups/usecases/mariadb"
|
||||
usecases_mysql "postgresus-backend/internal/features/backups/backups/usecases/mysql"
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
@@ -13,20 +17,20 @@ import (
|
||||
|
||||
type CreateBackupUsecase struct {
|
||||
CreatePostgresqlBackupUsecase *usecases_postgresql.CreatePostgresqlBackupUsecase
|
||||
CreateMysqlBackupUsecase *usecases_mysql.CreateMysqlBackupUsecase
|
||||
CreateMariadbBackupUsecase *usecases_mariadb.CreateMariadbBackupUsecase
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database and returns the backup metadata
|
||||
func (uc *CreateBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
database *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) (*usecases_postgresql.BackupMetadata, error) {
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
switch database.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.CreatePostgresqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
@@ -35,7 +39,28 @@ func (uc *CreateBackupUsecase) Execute(
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
}
|
||||
|
||||
return nil, errors.New("database type not supported")
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.CreateMysqlBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.CreateMariadbBackupUsecase.Execute(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
database,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
)
|
||||
|
||||
default:
|
||||
return nil, errors.New("database type not supported")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
usecases_mariadb "postgresus-backend/internal/features/backups/backups/usecases/mariadb"
|
||||
usecases_mysql "postgresus-backend/internal/features/backups/backups/usecases/mysql"
|
||||
usecases_postgresql "postgresus-backend/internal/features/backups/backups/usecases/postgresql"
|
||||
)
|
||||
|
||||
var createBackupUsecase = &CreateBackupUsecase{
|
||||
usecases_postgresql.GetCreatePostgresqlBackupUsecase(),
|
||||
usecases_mysql.GetCreateMysqlBackupUsecase(),
|
||||
usecases_mariadb.GetCreateMariadbBackupUsecase(),
|
||||
}
|
||||
|
||||
func GetCreateBackupUsecase() *CreateBackupUsecase {
|
||||
|
||||
@@ -0,0 +1,595 @@
|
||||
package usecases_mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
usecases_common "postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mariadbtypes "postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
backupTimeout = 23 * time.Hour
|
||||
shutdownCheckInterval = 1 * time.Second
|
||||
copyBufferSize = 8 * 1024 * 1024
|
||||
progressReportIntervalMB = 1.0
|
||||
zstdStorageCompressionLevel = 3
|
||||
exitCodeGenericError = 1
|
||||
exitCodeConnectionError = 2
|
||||
)
|
||||
|
||||
type CreateMariadbBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
type writeResult struct {
|
||||
bytesWritten int
|
||||
writeErr error
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info(
|
||||
"Creating MariaDB backup via mariadb-dump",
|
||||
"databaseId", db.ID,
|
||||
"storageId", storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
mdb := db.Mariadb
|
||||
if mdb == nil {
|
||||
return nil, fmt.Errorf("mariadb database configuration is required")
|
||||
}
|
||||
|
||||
if mdb.Database == nil || *mdb.Database == "" {
|
||||
return nil, fmt.Errorf("database name is required for mariadb-dump backups")
|
||||
}
|
||||
|
||||
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, mdb.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
|
||||
}
|
||||
|
||||
args := uc.buildMariadbDumpArgs(mdb)
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadbDump,
|
||||
mdb.Version,
|
||||
config.GetEnv().EnvMode,
|
||||
config.GetEnv().MariadbInstallDir,
|
||||
),
|
||||
args,
|
||||
decryptedPassword,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
mdb,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) buildMariadbDumpArgs(
|
||||
mdb *mariadbtypes.MariadbDatabase,
|
||||
) []string {
|
||||
args := []string{
|
||||
"--host=" + mdb.Host,
|
||||
"--port=" + strconv.Itoa(mdb.Port),
|
||||
"--user=" + mdb.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
args = append(args, "--compress")
|
||||
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
args = append(args, *mdb.Database)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
password string,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info("Streaming MariaDB backup to storage", "mariadbBin", mariadbBin)
|
||||
|
||||
ctx, cancel := uc.createBackupContext(parentCtx)
|
||||
defer cancel()
|
||||
|
||||
myCnfFile, err := uc.createTempMyCnfFile(mdbConfig, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create .my.cnf: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
cmd := exec.CommandContext(ctx, mariadbBin, fullArgs...)
|
||||
uc.logger.Info("Executing MariaDB backup command", "command", cmd.String())
|
||||
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env,
|
||||
"MYSQL_PWD=",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
)
|
||||
|
||||
pgStdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
pgStderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrCh := make(chan []byte, 1)
|
||||
go func() {
|
||||
stderrOutput, _ := io.ReadAll(pgStderr)
|
||||
stderrCh <- stderrOutput
|
||||
}()
|
||||
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zstdWriter, err := zstd.NewWriter(finalWriter,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(zstdStorageCompressionLevel)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create zstd writer: %w", err)
|
||||
}
|
||||
countingWriter := usecases_common.NewCountingWriter(zstdWriter)
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start %s: %w", filepath.Base(mariadbBin), err)
|
||||
}
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
bytesWrittenCh := make(chan int64, 1)
|
||||
go func() {
|
||||
bytesWritten, err := uc.copyWithShutdownCheck(
|
||||
ctx,
|
||||
countingWriter,
|
||||
pgStdout,
|
||||
backupProgressListener,
|
||||
)
|
||||
bytesWrittenCh <- bytesWritten
|
||||
copyResultCh <- err
|
||||
}()
|
||||
|
||||
copyErr := <-copyResultCh
|
||||
bytesWritten := <-bytesWrittenCh
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
uc.cleanupOnCancellation(zstdWriter, encryptionWriter, storageWriter, saveErrCh)
|
||||
return nil, uc.checkCancellationReason()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := zstdWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close zstd writer", "error", err)
|
||||
}
|
||||
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
|
||||
<-saveErrCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
saveErr := <-saveErrCh
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
if waitErr == nil && copyErr == nil && saveErr == nil && backupProgressListener != nil {
|
||||
sizeMB := float64(bytesWritten) / (1024 * 1024)
|
||||
backupProgressListener(sizeMB)
|
||||
}
|
||||
|
||||
switch {
|
||||
case waitErr != nil:
|
||||
return nil, uc.buildMariadbDumpErrorMessage(waitErr, stderrOutput, mariadbBin)
|
||||
case copyErr != nil:
|
||||
return nil, fmt.Errorf("copy to storage: %w", copyErr)
|
||||
case saveErr != nil:
|
||||
return nil, fmt.Errorf("save to storage: %w", saveErr)
|
||||
}
|
||||
|
||||
return &backupMetadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp("", "mycnf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
user=%s
|
||||
password="%s"
|
||||
host=%s
|
||||
port=%d
|
||||
`, mdbConfig.Username, tools.EscapeMariadbPassword(password), mdbConfig.Host, mdbConfig.Port)
|
||||
|
||||
if mdbConfig.IsHttps {
|
||||
content += "ssl=true\n"
|
||||
} else {
|
||||
content += "ssl=false\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
return myCnfFile, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (int64, error) {
|
||||
buf := make([]byte, copyBufferSize)
|
||||
var totalBytesWritten int64
|
||||
var lastReportedMB float64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
|
||||
}
|
||||
|
||||
bytesRead, readErr := src.Read(buf)
|
||||
if bytesRead > 0 {
|
||||
writeResultCh := make(chan writeResult, 1)
|
||||
go func() {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
writeResultCh <- writeResult{bytesWritten, writeErr}
|
||||
}()
|
||||
|
||||
var bytesWritten int
|
||||
var writeErr error
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled during write: %w", ctx.Err())
|
||||
case result := <-writeResultCh:
|
||||
bytesWritten = result.bytesWritten
|
||||
writeErr = result.writeErr
|
||||
}
|
||||
|
||||
if bytesWritten < 0 || bytesRead < bytesWritten {
|
||||
bytesWritten = 0
|
||||
if writeErr == nil {
|
||||
writeErr = fmt.Errorf("invalid write result")
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return totalBytesWritten, writeErr
|
||||
}
|
||||
|
||||
if bytesRead != bytesWritten {
|
||||
return totalBytesWritten, io.ErrShortWrite
|
||||
}
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
|
||||
if backupProgressListener != nil {
|
||||
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
|
||||
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
|
||||
backupProgressListener(currentSizeMB)
|
||||
lastReportedMB = currentSizeMB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) createBackupContext(
|
||||
parentCtx context.Context,
|
||||
) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(shutdownCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
|
||||
metadata := usecases_common.BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
metadata.Encryption = backups_config.BackupEncryptionNone
|
||||
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
|
||||
return storageWriter, nil, metadata, nil
|
||||
}
|
||||
|
||||
salt, err := backup_encryption.GenerateSalt()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := backup_encryption.GenerateNonce()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
||||
storageWriter,
|
||||
masterKey,
|
||||
backupID,
|
||||
salt,
|
||||
nonce,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
||||
}
|
||||
|
||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
||||
metadata.EncryptionSalt = &saltBase64
|
||||
metadata.EncryptionIV = &nonceBase64
|
||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||
|
||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||
return encWriter, encWriter, metadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) cleanupOnCancellation(
|
||||
zstdWriter *zstd.Encoder,
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
saveErrCh chan error,
|
||||
) {
|
||||
if zstdWriter != nil {
|
||||
go func() {
|
||||
if closeErr := zstdWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close zstd writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
if closeErr := encryptionWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close encrypting writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
|
||||
}
|
||||
|
||||
<-saveErrCh
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) closeWriters(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
) error {
|
||||
encryptionCloseErrCh := make(chan error, 1)
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
closeErr := encryptionWriter.Close()
|
||||
if closeErr != nil {
|
||||
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
|
||||
}
|
||||
encryptionCloseErrCh <- closeErr
|
||||
}()
|
||||
} else {
|
||||
encryptionCloseErrCh <- nil
|
||||
}
|
||||
|
||||
encryptionCloseErr := <-encryptionCloseErrCh
|
||||
if encryptionCloseErr != nil {
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
|
||||
}
|
||||
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) checkCancellationReason() error {
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) buildMariadbDumpErrorMessage(
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
mariadbBin string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(mariadbBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
|
||||
return uc.handleConnectionErrors(stderrStr)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func (uc *CreateMariadbBackupUsecase) handleConnectionErrors(stderrStr string) error {
|
||||
if containsIgnoreCase(stderrStr, "access denied") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB access denied. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "can't connect") ||
|
||||
containsIgnoreCase(stderrStr, "connection refused") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB connection refused. Check if the server is running and accessible. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "unknown database") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB database does not exist. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB SSL connection failed. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB connection timeout. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("MariaDB connection or authentication error. stderr: %s", stderrStr)
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package usecases_mariadb
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var createMariadbBackupUsecase = &CreateMariadbBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
func GetCreateMariadbBackupUsecase() *CreateMariadbBackupUsecase {
|
||||
return createMariadbBackupUsecase
|
||||
}
|
||||
@@ -0,0 +1,608 @@
|
||||
package usecases_mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
usecases_common "postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mysqltypes "postgresus-backend/internal/features/databases/databases/mysql"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
const (
|
||||
backupTimeout = 23 * time.Hour
|
||||
shutdownCheckInterval = 1 * time.Second
|
||||
copyBufferSize = 8 * 1024 * 1024
|
||||
progressReportIntervalMB = 1.0
|
||||
zstdStorageCompressionLevel = 3
|
||||
exitCodeGenericError = 1
|
||||
exitCodeConnectionError = 2
|
||||
)
|
||||
|
||||
type CreateMysqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
type writeResult struct {
|
||||
bytesWritten int
|
||||
writeErr error
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
db *databases.Database,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info(
|
||||
"Creating MySQL backup via mysqldump",
|
||||
"databaseId", db.ID,
|
||||
"storageId", storage.ID,
|
||||
)
|
||||
|
||||
if !backupConfig.IsBackupsEnabled {
|
||||
return nil, fmt.Errorf("backups are not enabled for this database: \"%s\"", db.Name)
|
||||
}
|
||||
|
||||
my := db.Mysql
|
||||
if my == nil {
|
||||
return nil, fmt.Errorf("mysql database configuration is required")
|
||||
}
|
||||
|
||||
if my.Database == nil || *my.Database == "" {
|
||||
return nil, fmt.Errorf("database name is required for mysqldump backups")
|
||||
}
|
||||
|
||||
decryptedPassword, err := uc.fieldEncryptor.Decrypt(db.ID, my.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt database password: %w", err)
|
||||
}
|
||||
|
||||
args := uc.buildMysqldumpArgs(my)
|
||||
|
||||
return uc.streamToStorage(
|
||||
ctx,
|
||||
backupID,
|
||||
backupConfig,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
tools.MysqlExecutableMysqldump,
|
||||
config.GetEnv().EnvMode,
|
||||
config.GetEnv().MysqlInstallDir,
|
||||
),
|
||||
args,
|
||||
decryptedPassword,
|
||||
storage,
|
||||
backupProgressListener,
|
||||
my,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) buildMysqldumpArgs(my *mysqltypes.MysqlDatabase) []string {
|
||||
args := []string{
|
||||
"--host=" + my.Host,
|
||||
"--port=" + strconv.Itoa(my.Port),
|
||||
"--user=" + my.Username,
|
||||
"--single-transaction",
|
||||
"--routines",
|
||||
"--triggers",
|
||||
"--events",
|
||||
"--set-gtid-purged=OFF",
|
||||
"--quick",
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
args = append(args, uc.getNetworkCompressionArgs(my.Version)...)
|
||||
|
||||
if my.IsHttps {
|
||||
args = append(args, "--ssl-mode=REQUIRED")
|
||||
}
|
||||
|
||||
if my.Database != nil && *my.Database != "" {
|
||||
args = append(args, *my.Database)
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) getNetworkCompressionArgs(version tools.MysqlVersion) []string {
|
||||
const zstdCompressionLevel = 3
|
||||
|
||||
switch version {
|
||||
case tools.MysqlVersion80, tools.MysqlVersion84:
|
||||
return []string{
|
||||
"--compression-algorithms=zstd",
|
||||
fmt.Sprintf("--zstd-compression-level=%d", zstdCompressionLevel),
|
||||
}
|
||||
case tools.MysqlVersion57:
|
||||
return []string{"--compress"}
|
||||
default:
|
||||
return []string{"--compress"}
|
||||
}
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) streamToStorage(
|
||||
parentCtx context.Context,
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
password string,
|
||||
storage *storages.Storage,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info("Streaming MySQL backup to storage", "mysqlBin", mysqlBin)
|
||||
|
||||
ctx, cancel := uc.createBackupContext(parentCtx)
|
||||
defer cancel()
|
||||
|
||||
myCnfFile, err := uc.createTempMyCnfFile(myConfig, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create .my.cnf: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
cmd := exec.CommandContext(ctx, mysqlBin, fullArgs...)
|
||||
uc.logger.Info("Executing MySQL backup command", "command", cmd.String())
|
||||
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env,
|
||||
"MYSQL_PWD=",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
)
|
||||
|
||||
pgStdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
pgStderr, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrCh := make(chan []byte, 1)
|
||||
go func() {
|
||||
stderrOutput, _ := io.ReadAll(pgStderr)
|
||||
stderrCh <- stderrOutput
|
||||
}()
|
||||
|
||||
storageReader, storageWriter := io.Pipe()
|
||||
|
||||
finalWriter, encryptionWriter, backupMetadata, err := uc.setupBackupEncryption(
|
||||
backupID,
|
||||
backupConfig,
|
||||
storageWriter,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zstdWriter, err := zstd.NewWriter(finalWriter,
|
||||
zstd.WithEncoderLevel(zstd.EncoderLevelFromZstd(zstdStorageCompressionLevel)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create zstd writer: %w", err)
|
||||
}
|
||||
countingWriter := usecases_common.NewCountingWriter(zstdWriter)
|
||||
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start %s: %w", filepath.Base(mysqlBin), err)
|
||||
}
|
||||
|
||||
copyResultCh := make(chan error, 1)
|
||||
bytesWrittenCh := make(chan int64, 1)
|
||||
go func() {
|
||||
bytesWritten, err := uc.copyWithShutdownCheck(
|
||||
ctx,
|
||||
countingWriter,
|
||||
pgStdout,
|
||||
backupProgressListener,
|
||||
)
|
||||
bytesWrittenCh <- bytesWritten
|
||||
copyResultCh <- err
|
||||
}()
|
||||
|
||||
copyErr := <-copyResultCh
|
||||
bytesWritten := <-bytesWrittenCh
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
uc.cleanupOnCancellation(zstdWriter, encryptionWriter, storageWriter, saveErrCh)
|
||||
return nil, uc.checkCancellationReason()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := zstdWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close zstd writer", "error", err)
|
||||
}
|
||||
if err := uc.closeWriters(encryptionWriter, storageWriter); err != nil {
|
||||
<-saveErrCh
|
||||
return nil, err
|
||||
}
|
||||
|
||||
saveErr := <-saveErrCh
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
if waitErr == nil && copyErr == nil && saveErr == nil && backupProgressListener != nil {
|
||||
sizeMB := float64(bytesWritten) / (1024 * 1024)
|
||||
backupProgressListener(sizeMB)
|
||||
}
|
||||
|
||||
switch {
|
||||
case waitErr != nil:
|
||||
return nil, uc.buildMysqldumpErrorMessage(waitErr, stderrOutput, mysqlBin)
|
||||
case copyErr != nil:
|
||||
return nil, fmt.Errorf("copy to storage: %w", copyErr)
|
||||
case saveErr != nil:
|
||||
return nil, fmt.Errorf("save to storage: %w", saveErr)
|
||||
}
|
||||
|
||||
return &backupMetadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp("", "mycnf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
user=%s
|
||||
password="%s"
|
||||
host=%s
|
||||
port=%d
|
||||
`, myConfig.Username, tools.EscapeMysqlPassword(password), myConfig.Host, myConfig.Port)
|
||||
|
||||
if myConfig.IsHttps {
|
||||
content += "ssl-mode=REQUIRED\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
return myCnfFile, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (int64, error) {
|
||||
buf := make([]byte, copyBufferSize)
|
||||
var totalBytesWritten int64
|
||||
var lastReportedMB float64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
|
||||
}
|
||||
|
||||
bytesRead, readErr := src.Read(buf)
|
||||
if bytesRead > 0 {
|
||||
writeResultCh := make(chan writeResult, 1)
|
||||
go func() {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
writeResultCh <- writeResult{bytesWritten, writeErr}
|
||||
}()
|
||||
|
||||
var bytesWritten int
|
||||
var writeErr error
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled during write: %w", ctx.Err())
|
||||
case result := <-writeResultCh:
|
||||
bytesWritten = result.bytesWritten
|
||||
writeErr = result.writeErr
|
||||
}
|
||||
|
||||
if bytesWritten < 0 || bytesRead < bytesWritten {
|
||||
bytesWritten = 0
|
||||
if writeErr == nil {
|
||||
writeErr = fmt.Errorf("invalid write result")
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return totalBytesWritten, writeErr
|
||||
}
|
||||
|
||||
if bytesRead != bytesWritten {
|
||||
return totalBytesWritten, io.ErrShortWrite
|
||||
}
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
|
||||
if backupProgressListener != nil {
|
||||
currentSizeMB := float64(totalBytesWritten) / (1024 * 1024)
|
||||
if currentSizeMB >= lastReportedMB+progressReportIntervalMB {
|
||||
backupProgressListener(currentSizeMB)
|
||||
lastReportedMB = currentSizeMB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) createBackupContext(
|
||||
parentCtx context.Context,
|
||||
) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(parentCtx, backupTimeout)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(shutdownCheckInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
|
||||
metadata := usecases_common.BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
metadata.Encryption = backups_config.BackupEncryptionNone
|
||||
uc.logger.Info("Encryption disabled for backup", "backupId", backupID)
|
||||
return storageWriter, nil, metadata, nil
|
||||
}
|
||||
|
||||
salt, err := backup_encryption.GenerateSalt()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate salt: %w", err)
|
||||
}
|
||||
|
||||
nonce, err := backup_encryption.GenerateNonce()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to generate nonce: %w", err)
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to get master key: %w", err)
|
||||
}
|
||||
|
||||
encWriter, err := backup_encryption.NewEncryptionWriter(
|
||||
storageWriter,
|
||||
masterKey,
|
||||
backupID,
|
||||
salt,
|
||||
nonce,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, nil, metadata, fmt.Errorf("failed to create encrypting writer: %w", err)
|
||||
}
|
||||
|
||||
saltBase64 := base64.StdEncoding.EncodeToString(salt)
|
||||
nonceBase64 := base64.StdEncoding.EncodeToString(nonce)
|
||||
metadata.EncryptionSalt = &saltBase64
|
||||
metadata.EncryptionIV = &nonceBase64
|
||||
metadata.Encryption = backups_config.BackupEncryptionEncrypted
|
||||
|
||||
uc.logger.Info("Encryption enabled for backup", "backupId", backupID)
|
||||
return encWriter, encWriter, metadata, nil
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) cleanupOnCancellation(
|
||||
zstdWriter *zstd.Encoder,
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
saveErrCh chan error,
|
||||
) {
|
||||
if zstdWriter != nil {
|
||||
go func() {
|
||||
if closeErr := zstdWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close zstd writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
if closeErr := encryptionWriter.Close(); closeErr != nil {
|
||||
uc.logger.Error(
|
||||
"Failed to close encrypting writer during cancellation",
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer during cancellation", "error", err)
|
||||
}
|
||||
|
||||
<-saveErrCh
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) closeWriters(
|
||||
encryptionWriter *backup_encryption.EncryptionWriter,
|
||||
storageWriter io.WriteCloser,
|
||||
) error {
|
||||
encryptionCloseErrCh := make(chan error, 1)
|
||||
if encryptionWriter != nil {
|
||||
go func() {
|
||||
closeErr := encryptionWriter.Close()
|
||||
if closeErr != nil {
|
||||
uc.logger.Error("Failed to close encrypting writer", "error", closeErr)
|
||||
}
|
||||
encryptionCloseErrCh <- closeErr
|
||||
}()
|
||||
} else {
|
||||
encryptionCloseErrCh <- nil
|
||||
}
|
||||
|
||||
encryptionCloseErr := <-encryptionCloseErrCh
|
||||
if encryptionCloseErr != nil {
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer after encryption error", "error", err)
|
||||
}
|
||||
return fmt.Errorf("failed to close encryption writer: %w", encryptionCloseErr)
|
||||
}
|
||||
|
||||
if err := storageWriter.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close pipe writer", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) checkCancellationReason() error {
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("backup cancelled due to shutdown")
|
||||
}
|
||||
return fmt.Errorf("backup cancelled")
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) buildMysqldumpErrorMessage(
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
mysqlBin string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(mysqlBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
exitErr, ok := waitErr.(*exec.ExitError)
|
||||
if !ok {
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
exitCode := exitErr.ExitCode()
|
||||
|
||||
if exitCode == exitCodeGenericError || exitCode == exitCodeConnectionError {
|
||||
return uc.handleConnectionErrors(stderrStr)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func (uc *CreateMysqlBackupUsecase) handleConnectionErrors(stderrStr string) error {
|
||||
if containsIgnoreCase(stderrStr, "access denied") {
|
||||
return fmt.Errorf(
|
||||
"MySQL access denied. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "can't connect") ||
|
||||
containsIgnoreCase(stderrStr, "connection refused") {
|
||||
return fmt.Errorf(
|
||||
"MySQL connection refused. Check if the server is running and accessible. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "unknown database") {
|
||||
return fmt.Errorf(
|
||||
"MySQL database does not exist. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") {
|
||||
return fmt.Errorf(
|
||||
"MySQL SSL connection failed. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"MySQL connection timeout. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("MySQL connection or authentication error. stderr: %s", stderrStr)
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package usecases_mysql
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var createMysqlBackupUsecase = &CreateMysqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
secrets.GetSecretKeyService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
|
||||
func GetCreateMysqlBackupUsecase() *CreateMysqlBackupUsecase {
|
||||
return createMysqlBackupUsecase
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
backup_encryption "postgresus-backend/internal/features/backups/backups/encryption"
|
||||
usecases_common "postgresus-backend/internal/features/backups/backups/usecases/common"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
pgtypes "postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
@@ -30,7 +31,7 @@ import (
|
||||
const (
|
||||
backupTimeout = 23 * time.Hour
|
||||
shutdownCheckInterval = 1 * time.Second
|
||||
copyBufferSize = 32 * 1024
|
||||
copyBufferSize = 8 * 1024 * 1024
|
||||
progressReportIntervalMB = 1.0
|
||||
pgConnectTimeout = 30
|
||||
compressionLevel = 5
|
||||
@@ -45,7 +46,11 @@ type CreatePostgresqlBackupUsecase struct {
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
// Execute creates a backup of the database
|
||||
type writeResult struct {
|
||||
bytesWritten int
|
||||
writeErr error
|
||||
}
|
||||
|
||||
func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
ctx context.Context,
|
||||
backupID uuid.UUID,
|
||||
@@ -55,7 +60,7 @@ func (uc *CreatePostgresqlBackupUsecase) Execute(
|
||||
backupProgressListener func(
|
||||
completedMBs float64,
|
||||
),
|
||||
) (*BackupMetadata, error) {
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info(
|
||||
"Creating PostgreSQL backup via pg_dump custom format",
|
||||
"databaseId",
|
||||
@@ -114,7 +119,7 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
storage *storages.Storage,
|
||||
db *databases.Database,
|
||||
backupProgressListener func(completedMBs float64),
|
||||
) (*BackupMetadata, error) {
|
||||
) (*usecases_common.BackupMetadata, error) {
|
||||
uc.logger.Info("Streaming PostgreSQL backup to storage", "pgBin", pgBin, "args", args)
|
||||
|
||||
ctx, cancel := uc.createBackupContext(parentCtx)
|
||||
@@ -126,7 +131,8 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
}
|
||||
defer func() {
|
||||
if pgpassFile != "" {
|
||||
_ = os.Remove(pgpassFile)
|
||||
// Remove the entire temp directory (which contains the .pgpass file)
|
||||
_ = os.RemoveAll(filepath.Dir(pgpassFile))
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -165,14 +171,14 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
countingWriter := &CountingWriter{writer: finalWriter}
|
||||
countingWriter := usecases_common.NewCountingWriter(finalWriter)
|
||||
|
||||
// The backup ID becomes the object key / filename in storage
|
||||
|
||||
// Start streaming into storage in its own goroutine
|
||||
saveErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
saveErr := storage.SaveFile(uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErr := storage.SaveFile(ctx, uc.fieldEncryptor, uc.logger, backupID, storageReader)
|
||||
saveErrCh <- saveErr
|
||||
}()
|
||||
|
||||
@@ -195,12 +201,10 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
copyResultCh <- err
|
||||
}()
|
||||
|
||||
// Wait for the copy to finish first, then the dump process
|
||||
copyErr := <-copyResultCh
|
||||
bytesWritten := <-bytesWrittenCh
|
||||
waitErr := cmd.Wait()
|
||||
|
||||
// Check for shutdown or cancellation before finalizing
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
uc.cleanupOnCancellation(encryptionWriter, storageWriter, saveErrCh)
|
||||
@@ -213,7 +217,6 @@ func (uc *CreatePostgresqlBackupUsecase) streamToStorage(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Wait until storage ends reading
|
||||
saveErr := <-saveErrCh
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
@@ -267,7 +270,23 @@ func (uc *CreatePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
|
||||
bytesRead, readErr := src.Read(buf)
|
||||
if bytesRead > 0 {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
writeResultCh := make(chan writeResult, 1)
|
||||
go func() {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
writeResultCh <- writeResult{bytesWritten, writeErr}
|
||||
}()
|
||||
|
||||
var bytesWritten int
|
||||
var writeErr error
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled during write: %w", ctx.Err())
|
||||
case result := <-writeResultCh:
|
||||
bytesWritten = result.bytesWritten
|
||||
writeErr = result.writeErr
|
||||
}
|
||||
|
||||
if bytesWritten < 0 || bytesRead < bytesWritten {
|
||||
bytesWritten = 0
|
||||
if writeErr == nil {
|
||||
@@ -316,6 +335,10 @@ func (uc *CreatePostgresqlBackupUsecase) buildPgDumpArgs(pg *pgtypes.PostgresqlD
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
for _, schema := range pg.IncludeSchemas {
|
||||
args = append(args, "-n", schema)
|
||||
}
|
||||
|
||||
compressionArgs := uc.getCompressionArgs(pg.Version)
|
||||
return append(args, compressionArgs...)
|
||||
}
|
||||
@@ -354,6 +377,9 @@ func (uc *CreatePostgresqlBackupUsecase) createBackupContext(
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-parentCtx.Done():
|
||||
cancel()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
@@ -445,8 +471,8 @@ func (uc *CreatePostgresqlBackupUsecase) setupBackupEncryption(
|
||||
backupID uuid.UUID,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
storageWriter io.WriteCloser,
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, BackupMetadata, error) {
|
||||
metadata := BackupMetadata{}
|
||||
) (io.Writer, *backup_encryption.EncryptionWriter, usecases_common.BackupMetadata, error) {
|
||||
metadata := usecases_common.BackupMetadata{}
|
||||
|
||||
if backupConfig.Encryption != backups_config.BackupEncryptionEncrypted {
|
||||
metadata.Encryption = backups_config.BackupEncryptionNone
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package usecases_postgresql
|
||||
|
||||
import "io"
|
||||
|
||||
// CountingWriter wraps an io.Writer and counts the bytes written to it
|
||||
type CountingWriter struct {
|
||||
writer io.Writer
|
||||
bytesWritten int64
|
||||
}
|
||||
|
||||
func (cw *CountingWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = cw.writer.Write(p)
|
||||
cw.bytesWritten += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
// GetBytesWritten returns the total number of bytes written
|
||||
func (cw *CountingWriter) GetBytesWritten() int64 {
|
||||
return cw.bytesWritten
|
||||
}
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
@@ -881,11 +882,9 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
// Verify password is encrypted
|
||||
assert.True(t, strings.HasPrefix(database.Postgresql.Password, "enc:"),
|
||||
"Password should be encrypted in database")
|
||||
|
||||
// Verify it can be decrypted back to original
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Postgresql.Password)
|
||||
assert.NoError(t, err)
|
||||
@@ -895,6 +894,55 @@ func Test_DatabaseSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
assert.Equal(t, "", database.Postgresql.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "MariaDB Database",
|
||||
databaseType: DatabaseTypeMariadb,
|
||||
createDatabase: func(workspaceID uuid.UUID) *Database {
|
||||
testDbName := "test_db"
|
||||
return &Database{
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Test MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: "localhost",
|
||||
Port: 3306,
|
||||
Username: "root",
|
||||
Password: "original-password-secret",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
},
|
||||
updateDatabase: func(workspaceID uuid.UUID, databaseID uuid.UUID) *Database {
|
||||
testDbName := "updated_test_db"
|
||||
return &Database{
|
||||
ID: databaseID,
|
||||
WorkspaceID: &workspaceID,
|
||||
Name: "Updated MariaDB Database",
|
||||
Type: DatabaseTypeMariadb,
|
||||
Mariadb: &mariadb.MariadbDatabase{
|
||||
Version: tools.MariadbVersion114,
|
||||
Host: "updated-host",
|
||||
Port: 3307,
|
||||
Username: "updated_user",
|
||||
Password: "",
|
||||
Database: &testDbName,
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, database *Database) {
|
||||
assert.True(t, strings.HasPrefix(database.Mariadb.Password, "enc:"),
|
||||
"Password should be encrypted in database")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
decrypted, err := encryptor.Decrypt(database.ID, database.Mariadb.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password-secret", decrypted)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, database *Database) {
|
||||
assert.Equal(t, "", database.Mariadb.Password)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
432
backend/internal/features/databases/databases/mariadb/model.go
Normal file
432
backend/internal/features/databases/databases/mariadb/model.go
Normal file
@@ -0,0 +1,432 @@
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type MariadbDatabase struct {
|
||||
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
DatabaseID *uuid.UUID `json:"databaseId" gorm:"type:uuid;column:database_id"`
|
||||
|
||||
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TableName() string {
|
||||
return "mariadb_databases"
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return errors.New("port is required")
|
||||
}
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
if m.Password == "" {
|
||||
return errors.New("password is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return errors.New("database name is required for MariaDB backup")
|
||||
}
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to MariaDB database '%s': %w", *m.Database, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close MariaDB connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetConnMaxLifetime(15 * time.Second)
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("failed to ping MariaDB database '%s': %w", *m.Database, err)
|
||||
}
|
||||
|
||||
detectedVersion, err := detectMariadbVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) HideSensitiveData() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.Password = ""
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
||||
m.Version = incoming.Version
|
||||
m.Host = incoming.Host
|
||||
m.Port = incoming.Port
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) EncryptSensitiveFields(
|
||||
databaseID uuid.UUID,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if m.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(databaseID, m.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Password = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersionIfEmpty(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) PopulateVersion(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMariadbVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) CreateReadOnlyUser(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
maxRetries := 3
|
||||
for attempt := range maxRetries {
|
||||
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
|
||||
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
logger.Error("Failed to rollback transaction", "error", rollbackErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
newUsername,
|
||||
newPassword,
|
||||
))
|
||||
if err != nil {
|
||||
if attempt < maxRetries-1 {
|
||||
continue
|
||||
}
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
*m.Database,
|
||||
newUsername,
|
||||
))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant database privileges: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
newUsername,
|
||||
))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant PROCESS privilege: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "FLUSH PRIVILEGES")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to flush privileges: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
logger.Info(
|
||||
"Read-only MariaDB user created successfully",
|
||||
"username", newUsername,
|
||||
)
|
||||
return newUsername, newPassword, nil
|
||||
}
|
||||
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MariadbDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
)
|
||||
}
|
||||
|
||||
// detectMariadbVersion parses VERSION() output to detect MariaDB version
|
||||
// MariaDB returns strings like "10.11.6-MariaDB" or "11.4.2-MariaDB-1:11.4.2+maria~ubu2204"
|
||||
// Minor versions are mapped to the closest supported version (e.g., 12.1 → 12.0)
|
||||
func detectMariadbVersion(ctx context.Context, db *sql.DB) (tools.MariadbVersion, error) {
|
||||
var versionStr string
|
||||
err := db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&versionStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to query MariaDB version: %w", err)
|
||||
}
|
||||
|
||||
if !strings.Contains(strings.ToLower(versionStr), "mariadb") {
|
||||
return "", fmt.Errorf(
|
||||
"not a MariaDB server (version: %s). Use MySQL database type instead",
|
||||
versionStr,
|
||||
)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`^(\d+)\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(versionStr)
|
||||
if len(matches) < 3 {
|
||||
return "", fmt.Errorf("could not parse MariaDB version: %s", versionStr)
|
||||
}
|
||||
|
||||
major := matches[1]
|
||||
minor := matches[2]
|
||||
|
||||
return mapMariadbVersion(major, minor)
|
||||
}
|
||||
|
||||
func mapMariadbVersion(major, minor string) (tools.MariadbVersion, error) {
|
||||
switch major {
|
||||
case "5":
|
||||
return tools.MariadbVersion55, nil
|
||||
case "10":
|
||||
return mapMariadb10xVersion(minor)
|
||||
case "11":
|
||||
return mapMariadb11xVersion(minor)
|
||||
case "12":
|
||||
return tools.MariadbVersion120, nil
|
||||
default:
|
||||
return "", fmt.Errorf(
|
||||
"unsupported MariaDB major version: %s (supported: 5.x, 10.x, 11.x, 12.x)",
|
||||
major,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func mapMariadb10xVersion(minor string) (tools.MariadbVersion, error) {
|
||||
switch minor {
|
||||
case "1":
|
||||
return tools.MariadbVersion101, nil
|
||||
case "2":
|
||||
return tools.MariadbVersion102, nil
|
||||
case "3":
|
||||
return tools.MariadbVersion103, nil
|
||||
case "4":
|
||||
return tools.MariadbVersion104, nil
|
||||
case "5":
|
||||
return tools.MariadbVersion105, nil
|
||||
case "6", "7", "8", "9", "10":
|
||||
return tools.MariadbVersion106, nil
|
||||
default:
|
||||
return tools.MariadbVersion1011, nil
|
||||
}
|
||||
}
|
||||
|
||||
func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
|
||||
switch minor {
|
||||
case "0", "1", "2", "3", "4":
|
||||
return tools.MariadbVersion114, nil
|
||||
case "5", "6", "7", "8":
|
||||
return tools.MariadbVersion118, nil
|
||||
default:
|
||||
return tools.MariadbVersion118, nil
|
||||
}
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, error) {
|
||||
if encryptor == nil {
|
||||
return password, nil
|
||||
}
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
@@ -0,0 +1,387 @@
|
||||
package mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mariadbModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMariadbContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE readonly_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(
|
||||
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
readOnlyModel := &MariadbDatabase{
|
||||
Version: mariadbModel.Version,
|
||||
Host: mariadbModel.Host,
|
||||
Port: mariadbModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mariadbModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username,
|
||||
password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, dashDbName)
|
||||
dashDB, err := sqlx.Connect("mysql", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := &MariadbDatabase{
|
||||
Version: tools.MariadbVersion1011,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "pgs-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(dashDB, username)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMariadbContainer(t, env.TestMariadb1011Port, tools.MariadbVersion1011)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE drop_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mariadbModel := createMariadbModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mariadbModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
dropUserSafe(container.DB, username)
|
||||
}
|
||||
|
||||
type MariadbContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Version tools.MariadbVersion
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToMariadbContainer(
|
||||
t *testing.T,
|
||||
port string,
|
||||
version tools.MariadbVersion,
|
||||
) *MariadbContainer {
|
||||
if port == "" {
|
||||
t.Skipf("MariaDB port not configured for version %s", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Skipf("Failed to connect to MariaDB %s: %v", version, err)
|
||||
}
|
||||
|
||||
return &MariadbContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createMariadbModel(container *MariadbContainer) *MariadbDatabase {
|
||||
return &MariadbDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
|
||||
func dropUserSafe(db *sqlx.DB, username string) {
|
||||
// MariaDB 5.5 doesn't support DROP USER IF EXISTS, so we ignore errors
|
||||
_, _ = db.Exec(fmt.Sprintf("DROP USER '%s'@'%%'", username))
|
||||
}
|
||||
400
backend/internal/features/databases/databases/mysql/model.go
Normal file
400
backend/internal/features/databases/databases/mysql/model.go
Normal file
@@ -0,0 +1,400 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type MysqlDatabase struct {
|
||||
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
DatabaseID *uuid.UUID `json:"databaseId" gorm:"type:uuid;column:database_id"`
|
||||
|
||||
Version tools.MysqlVersion `json:"version" gorm:"type:text;not null"`
|
||||
|
||||
Host string `json:"host" gorm:"type:text;not null"`
|
||||
Port int `json:"port" gorm:"type:int;not null"`
|
||||
Username string `json:"username" gorm:"type:text;not null"`
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) TableName() string {
|
||||
return "mysql_databases"
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) Validate() error {
|
||||
if m.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
if m.Port == 0 {
|
||||
return errors.New("port is required")
|
||||
}
|
||||
if m.Username == "" {
|
||||
return errors.New("username is required")
|
||||
}
|
||||
if m.Password == "" {
|
||||
return errors.New("password is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) TestConnection(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return errors.New("database name is required for MySQL backup")
|
||||
}
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to MySQL database '%s': %w", *m.Database, err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close MySQL connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
db.SetConnMaxLifetime(15 * time.Second)
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
|
||||
if err := db.PingContext(ctx); err != nil {
|
||||
return fmt.Errorf("failed to ping MySQL database '%s': %w", *m.Database, err)
|
||||
}
|
||||
|
||||
detectedVersion, err := detectMysqlVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Version = detectedVersion
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) HideSensitiveData() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
m.Password = ""
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) Update(incoming *MysqlDatabase) {
|
||||
m.Version = incoming.Version
|
||||
m.Host = incoming.Host
|
||||
m.Port = incoming.Port
|
||||
m.Username = incoming.Username
|
||||
m.Database = incoming.Database
|
||||
m.IsHttps = incoming.IsHttps
|
||||
|
||||
if incoming.Password != "" {
|
||||
m.Password = incoming.Password
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) EncryptSensitiveFields(
|
||||
databaseID uuid.UUID,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if m.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(databaseID, m.Password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Password = encrypted
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersionIfEmpty(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return m.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) PopulateVersion(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if m.Database == nil || *m.Database == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectMysqlVersion(ctx, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Version = detectedVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) IsUserReadOnly(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (bool, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check grants: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
writePrivileges := []string{
|
||||
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
||||
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var grant string
|
||||
if err := rows.Scan(&grant); err != nil {
|
||||
return false, fmt.Errorf("failed to scan grant: %w", err)
|
||||
}
|
||||
|
||||
for _, priv := range writePrivileges {
|
||||
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return false, fmt.Errorf("error iterating grants: %w", err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) CreateReadOnlyUser(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, string, error) {
|
||||
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
dsn := m.buildDSN(password, *m.Database)
|
||||
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := db.Close(); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
maxRetries := 3
|
||||
for attempt := range maxRetries {
|
||||
newUsername := fmt.Sprintf("postgresus-%s", uuid.New().String()[:8])
|
||||
newPassword := uuid.New().String()
|
||||
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to begin transaction: %w", err)
|
||||
}
|
||||
|
||||
success := false
|
||||
defer func() {
|
||||
if !success {
|
||||
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
||||
logger.Error("Failed to rollback transaction", "error", rollbackErr)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
||||
newUsername,
|
||||
newPassword,
|
||||
))
|
||||
if err != nil {
|
||||
if attempt < maxRetries-1 {
|
||||
continue
|
||||
}
|
||||
return "", "", fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
||||
*m.Database,
|
||||
newUsername,
|
||||
))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant database privileges: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
||||
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
||||
newUsername,
|
||||
))
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant PROCESS privilege: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, "FLUSH PRIVILEGES")
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to flush privileges: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return "", "", fmt.Errorf("failed to commit transaction: %w", err)
|
||||
}
|
||||
|
||||
success = true
|
||||
logger.Info(
|
||||
"Read-only MySQL user created successfully",
|
||||
"username",
|
||||
newUsername,
|
||||
)
|
||||
return newUsername, newPassword, nil
|
||||
}
|
||||
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
}
|
||||
|
||||
func (m *MysqlDatabase) buildDSN(password string, database string) string {
|
||||
tlsConfig := "false"
|
||||
if m.IsHttps {
|
||||
tlsConfig = "true"
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
||||
m.Username,
|
||||
password,
|
||||
m.Host,
|
||||
m.Port,
|
||||
database,
|
||||
tlsConfig,
|
||||
)
|
||||
}
|
||||
|
||||
// detectMysqlVersion parses VERSION() output to detect MySQL version
|
||||
// Minor versions are mapped to the closest supported version (e.g., 8.1 → 8.0, 8.4+ → 8.4)
|
||||
func detectMysqlVersion(ctx context.Context, db *sql.DB) (tools.MysqlVersion, error) {
|
||||
var versionStr string
|
||||
err := db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&versionStr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to query MySQL version: %w", err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`^(\d+)\.(\d+)`)
|
||||
matches := re.FindStringSubmatch(versionStr)
|
||||
if len(matches) < 3 {
|
||||
return "", fmt.Errorf("could not parse MySQL version: %s", versionStr)
|
||||
}
|
||||
|
||||
major := matches[1]
|
||||
minor := matches[2]
|
||||
|
||||
return mapMysqlVersion(major, minor)
|
||||
}
|
||||
|
||||
func mapMysqlVersion(major, minor string) (tools.MysqlVersion, error) {
|
||||
switch major {
|
||||
case "5":
|
||||
return tools.MysqlVersion57, nil
|
||||
case "8":
|
||||
return mapMysql8xVersion(minor), nil
|
||||
case "9":
|
||||
return tools.MysqlVersion84, nil
|
||||
default:
|
||||
return "", fmt.Errorf(
|
||||
"unsupported MySQL major version: %s (supported: 5.x, 8.x, 9.x)",
|
||||
major,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func mapMysql8xVersion(minor string) tools.MysqlVersion {
|
||||
switch minor {
|
||||
case "0", "1", "2", "3":
|
||||
return tools.MysqlVersion80
|
||||
default:
|
||||
return tools.MysqlVersion84
|
||||
}
|
||||
}
|
||||
|
||||
func decryptPasswordIfNeeded(
|
||||
password string,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) (string, error) {
|
||||
if encryptor == nil {
|
||||
return password, nil
|
||||
}
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
@@ -0,0 +1,366 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
func Test_IsUserReadOnly_AdminUser_ReturnsFalse(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
isReadOnly, err := mysqlModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.False(t, isReadOnly, "Root user should not be read-only")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
container := connectToMysqlContainer(t, tc.port, tc.version)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS readonly_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS hack_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE readonly_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(
|
||||
`INSERT INTO readonly_test (data) VALUES ('test1'), ('test2')`,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "postgresus-"))
|
||||
|
||||
readOnlyModel := &MysqlDatabase{
|
||||
Version: mysqlModel.Version,
|
||||
Host: mysqlModel.Host,
|
||||
Port: mysqlModel.Port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: mysqlModel.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
isReadOnly, err := readOnlyModel.IsUserReadOnly(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isReadOnly, "Created user should be read-only")
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username,
|
||||
password,
|
||||
container.Host,
|
||||
container.Port,
|
||||
container.Database,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM readonly_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO readonly_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("UPDATE readonly_test SET data = 'hacked' WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("DELETE FROM readonly_test WHERE id = 1")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_FutureTables_NoSelectPermission(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(`DROP TABLE IF EXISTS future_table`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE future_table (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO future_table (data) VALUES ('future_data')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var data string
|
||||
err = readOnlyConn.Get(&data, "SELECT data FROM future_table LIMIT 1")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "future_data", data)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, dashDbName)
|
||||
dashDB, err := sqlx.Connect("mysql", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = dashDB.Exec(`INSERT INTO dash_test (data) VALUES ('test1'), ('test2')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mysqlModel := &MysqlDatabase{
|
||||
Version: tools.MysqlVersion80,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "postgresus-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_ReadOnlyUser_CannotDropOrAlterTables(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToMysqlContainer(t, env.TestMysql80Port, tools.MysqlVersion80)
|
||||
defer container.DB.Close()
|
||||
|
||||
_, err := container.DB.Exec(`DROP TABLE IF EXISTS drop_test`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`
|
||||
CREATE TABLE drop_test (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
data VARCHAR(255) NOT NULL
|
||||
)
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
_, err = container.DB.Exec(`INSERT INTO drop_test (data) VALUES ('test1')`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
mysqlModel := createMysqlModel(container)
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := mysqlModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, container.Host, container.Port, container.Database)
|
||||
readOnlyConn, err := sqlx.Connect("mysql", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
_, err = readOnlyConn.Exec("DROP TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("ALTER TABLE drop_test ADD COLUMN new_col VARCHAR(100)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("TRUNCATE TABLE drop_test")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, strings.ToLower(err.Error()), "denied")
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP USER IF EXISTS '%s'@'%%'", username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Version tools.MysqlVersion
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
func connectToMysqlContainer(
|
||||
t *testing.T,
|
||||
port string,
|
||||
version tools.MysqlVersion,
|
||||
) *MysqlContainer {
|
||||
if port == "" {
|
||||
t.Skipf("MySQL port not configured for version %s", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
host := "127.0.0.1"
|
||||
username := "root"
|
||||
password := "rootpassword"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
t.Skipf("Failed to connect to MySQL %s: %v", version, err)
|
||||
}
|
||||
|
||||
return &MysqlContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
func createMysqlModel(container *MysqlContainer) *MysqlDatabase {
|
||||
return &MysqlDatabase{
|
||||
Version: container.Version,
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &container.Database,
|
||||
IsHttps: false,
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type PostgresqlDatabase struct {
|
||||
@@ -29,17 +30,40 @@ type PostgresqlDatabase struct {
|
||||
Password string `json:"password" gorm:"type:text;not null"`
|
||||
Database *string `json:"database" gorm:"type:text"`
|
||||
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
||||
|
||||
// backup settings
|
||||
IncludeSchemas []string `json:"includeSchemas" gorm:"-"`
|
||||
IncludeSchemasString string `json:"-" gorm:"column:include_schemas;type:text;not null;default:''"`
|
||||
|
||||
// restore settings (not saved to DB)
|
||||
IsExcludeExtensions bool `json:"isExcludeExtensions" gorm:"-"`
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) TableName() string {
|
||||
return "postgresql_databases"
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) Validate() error {
|
||||
if p.Version == "" {
|
||||
return errors.New("version is required")
|
||||
func (p *PostgresqlDatabase) BeforeSave(_ *gorm.DB) error {
|
||||
if len(p.IncludeSchemas) > 0 {
|
||||
p.IncludeSchemasString = strings.Join(p.IncludeSchemas, ",")
|
||||
} else {
|
||||
p.IncludeSchemasString = ""
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) AfterFind(_ *gorm.DB) error {
|
||||
if p.IncludeSchemasString != "" {
|
||||
p.IncludeSchemas = strings.Split(p.IncludeSchemasString, ",")
|
||||
} else {
|
||||
p.IncludeSchemas = []string{}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *PostgresqlDatabase) Validate() error {
|
||||
if p.Host == "" {
|
||||
return errors.New("host is required")
|
||||
}
|
||||
@@ -85,6 +109,7 @@ func (p *PostgresqlDatabase) Update(incoming *PostgresqlDatabase) {
|
||||
p.Username = incoming.Username
|
||||
p.Database = incoming.Database
|
||||
p.IsHttps = incoming.IsHttps
|
||||
p.IncludeSchemas = incoming.IncludeSchemas
|
||||
|
||||
if incoming.Password != "" {
|
||||
p.Password = incoming.Password
|
||||
@@ -106,6 +131,58 @@ func (p *PostgresqlDatabase) EncryptSensitiveFields(
|
||||
return nil
|
||||
}
|
||||
|
||||
// PopulateVersionIfEmpty detects and sets the PostgreSQL version if not already set.
|
||||
// This should be called before encrypting sensitive fields.
|
||||
func (p *PostgresqlDatabase) PopulateVersionIfEmpty(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if p.Version != "" {
|
||||
return nil
|
||||
}
|
||||
return p.PopulateVersion(logger, encryptor, databaseID)
|
||||
}
|
||||
|
||||
// PopulateVersion detects and sets the PostgreSQL version by querying the database.
|
||||
func (p *PostgresqlDatabase) PopulateVersion(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
databaseID uuid.UUID,
|
||||
) error {
|
||||
if p.Database == nil || *p.Database == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
password, err := decryptPasswordIfNeeded(p.Password, encryptor, databaseID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
connStr := buildConnectionStringForDB(p, *p.Database, password)
|
||||
|
||||
conn, err := pgx.Connect(ctx, connStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := conn.Close(ctx); closeErr != nil {
|
||||
logger.Error("Failed to close connection", "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
detectedVersion, err := detectDatabaseVersion(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.Version = detectedVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsUserReadOnly checks if the database user has read-only privileges.
|
||||
//
|
||||
// This method performs a comprehensive security check by examining:
|
||||
@@ -286,8 +363,20 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
|
||||
// Retry logic for username collision
|
||||
maxRetries := 3
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
username := fmt.Sprintf("postgresus-%s", uuid.New().String()[:8])
|
||||
for attempt := range maxRetries {
|
||||
// Generate base username for PostgreSQL user creation
|
||||
baseUsername := fmt.Sprintf("postgresus-%s", uuid.New().String()[:8])
|
||||
|
||||
// For Supabase session pooler, the username format for connection is "username.projectid"
|
||||
// but the actual PostgreSQL user must be created with just the base name.
|
||||
// The pooler will strip the ".projectid" suffix when authenticating.
|
||||
connectionUsername := baseUsername
|
||||
if isSupabaseConnection(p.Host, p.Username) {
|
||||
if supabaseProjectID := extractSupabaseProjectID(p.Username); supabaseProjectID != "" {
|
||||
connectionUsername = fmt.Sprintf("%s.%s", baseUsername, supabaseProjectID)
|
||||
}
|
||||
}
|
||||
|
||||
newPassword := uuid.New().String()
|
||||
|
||||
tx, err := conn.Begin(ctx)
|
||||
@@ -305,9 +394,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}()
|
||||
|
||||
// Step 1: Create PostgreSQL user with LOGIN privilege
|
||||
// Note: We use baseUsername for the actual PostgreSQL user name if Supabase is used
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`, username, newPassword),
|
||||
fmt.Sprintf(`CREATE USER "%s" WITH PASSWORD '%s' LOGIN`, baseUsername, newPassword),
|
||||
)
|
||||
if err != nil {
|
||||
if err.Error() != "" && attempt < maxRetries-1 {
|
||||
@@ -331,28 +421,28 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
|
||||
// Now revoke from the specific user as well (belt and suspenders)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, username))
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE CREATE ON SCHEMA public FROM "%s"`, baseUsername))
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
"Failed to revoke CREATE on public schema from user",
|
||||
"error",
|
||||
err,
|
||||
"username",
|
||||
username,
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
|
||||
// Step 2: Grant database connection privilege and revoke TEMP
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE %s TO "%s"`, *p.Database, username),
|
||||
fmt.Sprintf(`GRANT CONNECT ON DATABASE "%s" TO "%s"`, *p.Database, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant connect privilege: %w", err)
|
||||
}
|
||||
|
||||
// Revoke TEMP privilege from PUBLIC role (like CREATE on public schema, TEMP is granted to PUBLIC by default)
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE TEMP ON DATABASE %s FROM PUBLIC`, *p.Database))
|
||||
_, err = tx.Exec(ctx, fmt.Sprintf(`REVOKE TEMP ON DATABASE "%s" FROM PUBLIC`, *p.Database))
|
||||
if err != nil {
|
||||
logger.Warn("Failed to revoke TEMP from PUBLIC", "error", err)
|
||||
}
|
||||
@@ -360,10 +450,10 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
// Also revoke from the specific user (belt and suspenders)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE TEMP ON DATABASE %s FROM "%s"`, *p.Database, username),
|
||||
fmt.Sprintf(`REVOKE TEMP ON DATABASE "%s" FROM "%s"`, *p.Database, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", username)
|
||||
logger.Warn("Failed to revoke TEMP privilege", "error", err, "username", baseUsername)
|
||||
}
|
||||
|
||||
// Step 3: Discover all user-created schemas
|
||||
@@ -396,7 +486,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
// Revoke CREATE specifically (handles inheritance from PUBLIC role)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA "%s" FROM "%s"`, schema, username),
|
||||
fmt.Sprintf(`REVOKE CREATE ON SCHEMA "%s" FROM "%s"`, schema, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(
|
||||
@@ -406,14 +496,14 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
"schema",
|
||||
schema,
|
||||
"username",
|
||||
username,
|
||||
baseUsername,
|
||||
)
|
||||
}
|
||||
|
||||
// Grant only USAGE (not CREATE)
|
||||
_, err = tx.Exec(
|
||||
ctx,
|
||||
fmt.Sprintf(`GRANT USAGE ON SCHEMA "%s" TO "%s"`, schema, username),
|
||||
fmt.Sprintf(`GRANT USAGE ON SCHEMA "%s" TO "%s"`, schema, baseUsername),
|
||||
)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to grant usage on schema %s: %w", schema, err)
|
||||
@@ -435,7 +525,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
EXECUTE format('GRANT SELECT ON ALL SEQUENCES IN SCHEMA %%I TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, username, username)
|
||||
`, baseUsername, baseUsername)
|
||||
|
||||
_, err = tx.Exec(ctx, grantSelectSQL)
|
||||
if err != nil {
|
||||
@@ -457,7 +547,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
EXECUTE format('ALTER DEFAULT PRIVILEGES IN SCHEMA %%I GRANT SELECT ON SEQUENCES TO "%s"', schema_rec.schema_name);
|
||||
END LOOP;
|
||||
END $$;
|
||||
`, username, username)
|
||||
`, baseUsername, baseUsername)
|
||||
|
||||
_, err = tx.Exec(ctx, defaultPrivilegesSQL)
|
||||
if err != nil {
|
||||
@@ -466,7 +556,7 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
|
||||
// Step 7: Verify user creation before committing
|
||||
var verifyUsername string
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, username)).
|
||||
err = tx.QueryRow(ctx, fmt.Sprintf(`SELECT rolname FROM pg_roles WHERE rolname = '%s'`, baseUsername)).
|
||||
Scan(&verifyUsername)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("failed to verify user creation: %w", err)
|
||||
@@ -477,8 +567,15 @@ func (p *PostgresqlDatabase) CreateReadOnlyUser(
|
||||
}
|
||||
|
||||
success = true
|
||||
logger.Info("Read-only user created successfully", "username", username)
|
||||
return username, newPassword, nil
|
||||
// Return connectionUsername (with project ID suffix for Supabase) for the caller to use when connecting
|
||||
logger.Info(
|
||||
"Read-only user created successfully",
|
||||
"username",
|
||||
baseUsername,
|
||||
"connectionUsername",
|
||||
connectionUsername,
|
||||
)
|
||||
return connectionUsername, newPassword, nil
|
||||
}
|
||||
|
||||
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
||||
@@ -521,10 +618,12 @@ func testSingleDatabaseConnection(
|
||||
}
|
||||
}()
|
||||
|
||||
// Check version after successful connection
|
||||
if err := verifyDatabaseVersion(ctx, conn, postgresDb.Version); err != nil {
|
||||
// Detect and set the database version automatically
|
||||
detectedVersion, err := detectDatabaseVersion(ctx, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
postgresDb.Version = detectedVersion
|
||||
|
||||
// Test if we can perform basic operations (like pg_dump would need)
|
||||
if err := testBasicOperations(ctx, conn, *postgresDb.Database); err != nil {
|
||||
@@ -538,35 +637,31 @@ func testSingleDatabaseConnection(
|
||||
return nil
|
||||
}
|
||||
|
||||
// verifyDatabaseVersion checks if the actual database version matches the specified version
|
||||
func verifyDatabaseVersion(
|
||||
ctx context.Context,
|
||||
conn *pgx.Conn,
|
||||
expectedVersion tools.PostgresqlVersion,
|
||||
) error {
|
||||
// detectDatabaseVersion queries and returns the PostgreSQL major version
|
||||
func detectDatabaseVersion(ctx context.Context, conn *pgx.Conn) (tools.PostgresqlVersion, error) {
|
||||
var versionStr string
|
||||
err := conn.QueryRow(ctx, "SELECT version()").Scan(&versionStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query database version: %w", err)
|
||||
return "", fmt.Errorf("failed to query database version: %w", err)
|
||||
}
|
||||
|
||||
// Parse version from string like "PostgreSQL 14.2 on x86_64-pc-linux-gnu..."
|
||||
re := regexp.MustCompile(`PostgreSQL (\d+)\.`)
|
||||
// or "PostgreSQL 16 maintained by Postgre BY..." (some builds omit minor version)
|
||||
re := regexp.MustCompile(`PostgreSQL (\d+)`)
|
||||
matches := re.FindStringSubmatch(versionStr)
|
||||
if len(matches) < 2 {
|
||||
return fmt.Errorf("could not parse version from: %s", versionStr)
|
||||
return "", fmt.Errorf("could not parse version from: %s", versionStr)
|
||||
}
|
||||
|
||||
actualVersion := tools.GetPostgresqlVersionEnum(matches[1])
|
||||
if actualVersion != expectedVersion {
|
||||
return fmt.Errorf(
|
||||
"you specified wrong version. Real version is %s, but you specified %s",
|
||||
actualVersion,
|
||||
expectedVersion,
|
||||
)
|
||||
}
|
||||
majorVersion := matches[1]
|
||||
|
||||
return nil
|
||||
// Map to known PostgresqlVersion enum values
|
||||
switch majorVersion {
|
||||
case "12", "13", "14", "15", "16", "17", "18":
|
||||
return tools.PostgresqlVersion(majorVersion), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported PostgreSQL version: %s", majorVersion)
|
||||
}
|
||||
}
|
||||
|
||||
// testBasicOperations tests basic operations that backup tools need
|
||||
@@ -594,7 +689,7 @@ func buildConnectionStringForDB(p *PostgresqlDatabase, dbName string, password s
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol",
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s default_query_exec_mode=simple_protocol standard_conforming_strings=on client_encoding=UTF8",
|
||||
p.Host,
|
||||
p.Port,
|
||||
p.Username,
|
||||
@@ -614,3 +709,15 @@ func decryptPasswordIfNeeded(
|
||||
}
|
||||
return encryptor.Decrypt(databaseID, password)
|
||||
}
|
||||
|
||||
func isSupabaseConnection(host, username string) bool {
|
||||
return strings.Contains(strings.ToLower(host), "supabase") ||
|
||||
strings.Contains(strings.ToLower(username), "supabase")
|
||||
}
|
||||
|
||||
func extractSupabaseProjectID(username string) string {
|
||||
if idx := strings.Index(username, "."); idx != -1 {
|
||||
return username[idx+1:]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -246,6 +246,188 @@ func Test_ReadOnlyUser_MultipleSchemas_AllAccessible(t *testing.T) {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_DatabaseNameWithDash_Success(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
container := connectToPostgresContainer(t, env.TestPostgres16Port)
|
||||
defer container.DB.Close()
|
||||
|
||||
dashDbName := "test-db-with-dash"
|
||||
|
||||
_, err := container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf(`CREATE DATABASE "%s"`, dashDbName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = container.DB.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS "%s"`, dashDbName))
|
||||
}()
|
||||
|
||||
dashDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, container.Username, container.Password, dashDbName)
|
||||
dashDB, err := sqlx.Connect("postgres", dashDSN)
|
||||
assert.NoError(t, err)
|
||||
defer dashDB.Close()
|
||||
|
||||
_, err = dashDB.Exec(`
|
||||
CREATE TABLE dash_test (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO dash_test (data) VALUES ('test1'), ('test2');
|
||||
`)
|
||||
assert.NoError(t, err)
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Version: tools.GetPostgresqlVersionEnum("16"),
|
||||
Host: container.Host,
|
||||
Port: container.Port,
|
||||
Username: container.Username,
|
||||
Password: container.Password,
|
||||
Database: &dashDbName,
|
||||
IsHttps: false,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
username, password, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, username)
|
||||
assert.NotEmpty(t, password)
|
||||
assert.True(t, strings.HasPrefix(username, "postgresus-"))
|
||||
|
||||
readOnlyDSN := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
container.Host, container.Port, username, password, dashDbName)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, "SELECT COUNT(*) FROM dash_test")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec("INSERT INTO dash_test (data) VALUES ('should-fail')")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, username))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to drop owned objects: %v", err)
|
||||
}
|
||||
|
||||
_, err = dashDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, username))
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_CreateReadOnlyUser_Supabase_UserCanReadButNotWrite(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
|
||||
if env.TestSupabaseHost == "" {
|
||||
t.Skip("Skipping Supabase test: missing environment variables")
|
||||
}
|
||||
|
||||
portInt, err := strconv.Atoi(env.TestSupabasePort)
|
||||
assert.NoError(t, err)
|
||||
|
||||
dsn := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
env.TestSupabaseUsername,
|
||||
env.TestSupabasePassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
|
||||
adminDB, err := sqlx.Connect("postgres", dsn)
|
||||
assert.NoError(t, err)
|
||||
defer adminDB.Close()
|
||||
|
||||
tableName := fmt.Sprintf(
|
||||
"readonly_test_%s",
|
||||
strings.ReplaceAll(uuid.New().String()[:8], "-", ""),
|
||||
)
|
||||
_, err = adminDB.Exec(fmt.Sprintf(`
|
||||
DROP TABLE IF EXISTS public.%s CASCADE;
|
||||
CREATE TABLE public.%s (
|
||||
id SERIAL PRIMARY KEY,
|
||||
data TEXT NOT NULL
|
||||
);
|
||||
INSERT INTO public.%s (data) VALUES ('test1'), ('test2');
|
||||
`, tableName, tableName, tableName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP TABLE IF EXISTS public.%s CASCADE`, tableName))
|
||||
}()
|
||||
|
||||
pgModel := &PostgresqlDatabase{
|
||||
Host: env.TestSupabaseHost,
|
||||
Port: portInt,
|
||||
Username: env.TestSupabaseUsername,
|
||||
Password: env.TestSupabasePassword,
|
||||
Database: &env.TestSupabaseDatabase,
|
||||
IsHttps: true,
|
||||
}
|
||||
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, nil))
|
||||
ctx := context.Background()
|
||||
|
||||
connectionUsername, newPassword, err := pgModel.CreateReadOnlyUser(ctx, logger, nil, uuid.New())
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, connectionUsername)
|
||||
assert.NotEmpty(t, newPassword)
|
||||
assert.True(t, strings.HasPrefix(connectionUsername, "postgresus-"))
|
||||
|
||||
baseUsername := connectionUsername
|
||||
if idx := strings.Index(connectionUsername, "."); idx != -1 {
|
||||
baseUsername = connectionUsername[:idx]
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP OWNED BY "%s" CASCADE`, baseUsername))
|
||||
_, _ = adminDB.Exec(fmt.Sprintf(`DROP USER IF EXISTS "%s"`, baseUsername))
|
||||
}()
|
||||
|
||||
readOnlyDSN := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=require",
|
||||
env.TestSupabaseHost,
|
||||
portInt,
|
||||
connectionUsername,
|
||||
newPassword,
|
||||
env.TestSupabaseDatabase,
|
||||
)
|
||||
readOnlyConn, err := sqlx.Connect("postgres", readOnlyDSN)
|
||||
assert.NoError(t, err)
|
||||
defer readOnlyConn.Close()
|
||||
|
||||
var count int
|
||||
err = readOnlyConn.Get(&count, fmt.Sprintf("SELECT COUNT(*) FROM public.%s", tableName))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, count)
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("INSERT INTO public.%s (data) VALUES ('should-fail')", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(
|
||||
fmt.Sprintf("UPDATE public.%s SET data = 'hacked' WHERE id = 1", tableName),
|
||||
)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec(fmt.Sprintf("DELETE FROM public.%s WHERE id = 1", tableName))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
|
||||
_, err = readOnlyConn.Exec("CREATE TABLE public.hack_table (id INT)")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied")
|
||||
}
|
||||
|
||||
type PostgresContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
|
||||
@@ -4,6 +4,8 @@ type DatabaseType string
|
||||
|
||||
const (
|
||||
DatabaseTypePostgres DatabaseType = "POSTGRES"
|
||||
DatabaseTypeMysql DatabaseType = "MYSQL"
|
||||
DatabaseTypeMariadb DatabaseType = "MARIADB"
|
||||
)
|
||||
|
||||
type HealthStatus string
|
||||
|
||||
@@ -3,6 +3,8 @@ package databases
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/databases/databases/mysql"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
@@ -21,6 +23,8 @@ type Database struct {
|
||||
Type DatabaseType `json:"type" gorm:"column:type;type:text;not null"`
|
||||
|
||||
Postgresql *postgresql.PostgresqlDatabase `json:"postgresql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Mysql *mysql.MysqlDatabase `json:"mysql,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
Mariadb *mariadb.MariadbDatabase `json:"mariadb,omitempty" gorm:"foreignKey:DatabaseID"`
|
||||
|
||||
Notifiers []notifiers.Notifier `json:"notifiers" gorm:"many2many:database_notifiers;"`
|
||||
|
||||
@@ -42,8 +46,17 @@ func (d *Database) Validate() error {
|
||||
if d.Postgresql == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
|
||||
return d.Postgresql.Validate()
|
||||
case DatabaseTypeMysql:
|
||||
if d.Mysql == nil {
|
||||
return errors.New("mysql database is required")
|
||||
}
|
||||
return d.Mysql.Validate()
|
||||
case DatabaseTypeMariadb:
|
||||
if d.Mariadb == nil {
|
||||
return errors.New("mariadb database is required")
|
||||
}
|
||||
return d.Mariadb.Validate()
|
||||
default:
|
||||
return errors.New("invalid database type: " + string(d.Type))
|
||||
}
|
||||
@@ -72,6 +85,28 @@ func (d *Database) EncryptSensitiveFields(encryptor encryption.FieldEncryptor) e
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
if d.Mysql != nil {
|
||||
return d.Mysql.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
if d.Mariadb != nil {
|
||||
return d.Mariadb.EncryptSensitiveFields(d.ID, encryptor)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Database) PopulateVersionIfEmpty(
|
||||
logger *slog.Logger,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) error {
|
||||
if d.Postgresql != nil {
|
||||
return d.Postgresql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mysql != nil {
|
||||
return d.Mysql.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
}
|
||||
if d.Mariadb != nil {
|
||||
return d.Mariadb.PopulateVersionIfEmpty(logger, encryptor, d.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -85,6 +120,14 @@ func (d *Database) Update(incoming *Database) {
|
||||
if d.Postgresql != nil && incoming.Postgresql != nil {
|
||||
d.Postgresql.Update(incoming.Postgresql)
|
||||
}
|
||||
case DatabaseTypeMysql:
|
||||
if d.Mysql != nil && incoming.Mysql != nil {
|
||||
d.Mysql.Update(incoming.Mysql)
|
||||
}
|
||||
case DatabaseTypeMariadb:
|
||||
if d.Mariadb != nil && incoming.Mariadb != nil {
|
||||
d.Mariadb.Update(incoming.Mariadb)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -92,6 +135,10 @@ func (d *Database) getSpecificDatabase() DatabaseConnector {
|
||||
switch d.Type {
|
||||
case DatabaseTypePostgres:
|
||||
return d.Postgresql
|
||||
case DatabaseTypeMysql:
|
||||
return d.Mysql
|
||||
case DatabaseTypeMariadb:
|
||||
return d.Mariadb
|
||||
}
|
||||
|
||||
panic("invalid database type: " + string(d.Type))
|
||||
|
||||
@@ -2,6 +2,8 @@ package databases
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/databases/databases/mysql"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/storage"
|
||||
|
||||
@@ -25,26 +27,33 @@ func (r *DatabaseRepository) Save(database *Database) (*Database, error) {
|
||||
if database.Postgresql == nil {
|
||||
return errors.New("postgresql configuration is required for PostgreSQL database")
|
||||
}
|
||||
|
||||
// Ensure DatabaseID is always set and never nil
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
case DatabaseTypeMysql:
|
||||
if database.Mysql == nil {
|
||||
return errors.New("mysql configuration is required for MySQL database")
|
||||
}
|
||||
database.Mysql.DatabaseID = &database.ID
|
||||
case DatabaseTypeMariadb:
|
||||
if database.Mariadb == nil {
|
||||
return errors.New("mariadb configuration is required for MariaDB database")
|
||||
}
|
||||
database.Mariadb.DatabaseID = &database.ID
|
||||
}
|
||||
|
||||
if isNew {
|
||||
if err := tx.Create(database).
|
||||
Omit("Postgresql", "Notifiers").
|
||||
Omit("Postgresql", "Mysql", "Mariadb", "Notifiers").
|
||||
Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(database).
|
||||
Omit("Postgresql", "Notifiers").
|
||||
Omit("Postgresql", "Mysql", "Mariadb", "Notifiers").
|
||||
Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Save the specific database type
|
||||
switch database.Type {
|
||||
case DatabaseTypePostgres:
|
||||
database.Postgresql.DatabaseID = &database.ID
|
||||
@@ -58,6 +67,30 @@ func (r *DatabaseRepository) Save(database *Database) (*Database, error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case DatabaseTypeMysql:
|
||||
database.Mysql.DatabaseID = &database.ID
|
||||
if database.Mysql.ID == uuid.Nil {
|
||||
database.Mysql.ID = uuid.New()
|
||||
if err := tx.Create(database.Mysql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(database.Mysql).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case DatabaseTypeMariadb:
|
||||
database.Mariadb.DatabaseID = &database.ID
|
||||
if database.Mariadb.ID == uuid.Nil {
|
||||
database.Mariadb.ID = uuid.New()
|
||||
if err := tx.Create(database.Mariadb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(database.Mariadb).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.
|
||||
@@ -83,6 +116,8 @@ func (r *DatabaseRepository) FindByID(id uuid.UUID) (*Database, error) {
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Postgresql").
|
||||
Preload("Mysql").
|
||||
Preload("Mariadb").
|
||||
Preload("Notifiers").
|
||||
Where("id = ?", id).
|
||||
First(&database).Error; err != nil {
|
||||
@@ -98,6 +133,8 @@ func (r *DatabaseRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Databa
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Postgresql").
|
||||
Preload("Mysql").
|
||||
Preload("Mariadb").
|
||||
Preload("Notifiers").
|
||||
Where("workspace_id = ?", workspaceID).
|
||||
Order("CASE WHEN health_status = 'UNAVAILABLE' THEN 1 WHEN health_status = 'AVAILABLE' THEN 2 WHEN health_status IS NULL THEN 3 ELSE 4 END, name ASC").
|
||||
@@ -128,6 +165,18 @@ func (r *DatabaseRepository) Delete(id uuid.UUID) error {
|
||||
Delete(&postgresql.PostgresqlDatabase{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
case DatabaseTypeMysql:
|
||||
if err := tx.
|
||||
Where("database_id = ?", id).
|
||||
Delete(&mysql.MysqlDatabase{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
case DatabaseTypeMariadb:
|
||||
if err := tx.
|
||||
Where("database_id = ?", id).
|
||||
Delete(&mariadb.MariadbDatabase{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Delete(&Database{}, id).Error; err != nil {
|
||||
@@ -158,6 +207,8 @@ func (r *DatabaseRepository) GetAllDatabases() ([]*Database, error) {
|
||||
if err := storage.
|
||||
GetDb().
|
||||
Preload("Postgresql").
|
||||
Preload("Mysql").
|
||||
Preload("Mariadb").
|
||||
Preload("Notifiers").
|
||||
Find(&databases).Error; err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
"postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/databases/databases/mysql"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
"postgresus-backend/internal/features/notifiers"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
@@ -68,6 +70,10 @@ func (s *DatabaseService) CreateDatabase(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := database.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
}
|
||||
|
||||
if err := database.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return nil, fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
@@ -125,6 +131,10 @@ func (s *DatabaseService) UpdateDatabase(
|
||||
return err
|
||||
}
|
||||
|
||||
if err := existingDatabase.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
}
|
||||
|
||||
if err := existingDatabase.EncryptSensitiveFields(s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to encrypt sensitive fields: %w", err)
|
||||
}
|
||||
@@ -396,6 +406,34 @@ func (s *DatabaseService) CopyDatabase(
|
||||
IsHttps: existingDatabase.Postgresql.IsHttps,
|
||||
}
|
||||
}
|
||||
case DatabaseTypeMysql:
|
||||
if existingDatabase.Mysql != nil {
|
||||
newDatabase.Mysql = &mysql.MysqlDatabase{
|
||||
ID: uuid.Nil,
|
||||
DatabaseID: nil,
|
||||
Version: existingDatabase.Mysql.Version,
|
||||
Host: existingDatabase.Mysql.Host,
|
||||
Port: existingDatabase.Mysql.Port,
|
||||
Username: existingDatabase.Mysql.Username,
|
||||
Password: existingDatabase.Mysql.Password,
|
||||
Database: existingDatabase.Mysql.Database,
|
||||
IsHttps: existingDatabase.Mysql.IsHttps,
|
||||
}
|
||||
}
|
||||
case DatabaseTypeMariadb:
|
||||
if existingDatabase.Mariadb != nil {
|
||||
newDatabase.Mariadb = &mariadb.MariadbDatabase{
|
||||
ID: uuid.Nil,
|
||||
DatabaseID: nil,
|
||||
Version: existingDatabase.Mariadb.Version,
|
||||
Host: existingDatabase.Mariadb.Host,
|
||||
Port: existingDatabase.Mariadb.Port,
|
||||
Username: existingDatabase.Mariadb.Username,
|
||||
Password: existingDatabase.Mariadb.Password,
|
||||
Database: existingDatabase.Mariadb.Database,
|
||||
IsHttps: existingDatabase.Mariadb.IsHttps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := newDatabase.Validate(); err != nil {
|
||||
@@ -510,19 +548,34 @@ func (s *DatabaseService) IsUserReadOnly(
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
if usingDatabase.Type != DatabaseTypePostgres {
|
||||
return false, errors.New("read-only check only supported for PostgreSQL databases")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return usingDatabase.Postgresql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
switch usingDatabase.Type {
|
||||
case DatabaseTypePostgres:
|
||||
return usingDatabase.Postgresql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMysql:
|
||||
return usingDatabase.Mysql.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMariadb:
|
||||
return usingDatabase.Mariadb.IsUserReadOnly(
|
||||
ctx,
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
usingDatabase.ID,
|
||||
)
|
||||
default:
|
||||
return false, errors.New("read-only check not supported for this database type")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DatabaseService) CreateReadOnlyUser(
|
||||
@@ -574,16 +627,29 @@ func (s *DatabaseService) CreateReadOnlyUser(
|
||||
usingDatabase = database
|
||||
}
|
||||
|
||||
if usingDatabase.Type != DatabaseTypePostgres {
|
||||
return "", "", errors.New("read-only user creation only supported for PostgreSQL")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
username, password, err := usingDatabase.Postgresql.CreateReadOnlyUser(
|
||||
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
|
||||
)
|
||||
var username, password string
|
||||
var err error
|
||||
|
||||
switch usingDatabase.Type {
|
||||
case DatabaseTypePostgres:
|
||||
username, password, err = usingDatabase.Postgresql.CreateReadOnlyUser(
|
||||
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMysql:
|
||||
username, password, err = usingDatabase.Mysql.CreateReadOnlyUser(
|
||||
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
|
||||
)
|
||||
case DatabaseTypeMariadb:
|
||||
username, password, err = usingDatabase.Mariadb.CreateReadOnlyUser(
|
||||
ctx, s.logger, s.fieldEncryptor, usingDatabase.ID,
|
||||
)
|
||||
default:
|
||||
return "", "", errors.New("read-only user creation not supported for this database type")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
@@ -7,4 +7,5 @@ const (
|
||||
IntervalDaily IntervalType = "DAILY"
|
||||
IntervalWeekly IntervalType = "WEEKLY"
|
||||
IntervalMonthly IntervalType = "MONTHLY"
|
||||
IntervalCron IntervalType = "CRON"
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@@ -12,11 +13,13 @@ type Interval struct {
|
||||
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
||||
Interval IntervalType `json:"interval" gorm:"type:text;not null"`
|
||||
|
||||
TimeOfDay *string `json:"timeOfDay" gorm:"type:text;"`
|
||||
TimeOfDay *string `json:"timeOfDay" gorm:"type:text;"`
|
||||
// only for WEEKLY
|
||||
Weekday *int `json:"weekday,omitempty" gorm:"type:int"`
|
||||
Weekday *int `json:"weekday,omitempty" gorm:"type:int"`
|
||||
// only for MONTHLY
|
||||
DayOfMonth *int `json:"dayOfMonth,omitempty" gorm:"type:int"`
|
||||
DayOfMonth *int `json:"dayOfMonth,omitempty" gorm:"type:int"`
|
||||
// only for CRON
|
||||
CronExpression *string `json:"cronExpression,omitempty" gorm:"type:text"`
|
||||
}
|
||||
|
||||
func (i *Interval) BeforeSave(tx *gorm.DB) error {
|
||||
@@ -40,6 +43,16 @@ func (i *Interval) Validate() error {
|
||||
return errors.New("day of month is required for monthly intervals")
|
||||
}
|
||||
|
||||
// for cron interval cron expression is required and must be valid
|
||||
if i.Interval == IntervalCron {
|
||||
if i.CronExpression == nil || *i.CronExpression == "" {
|
||||
return errors.New("cron expression is required for cron intervals")
|
||||
}
|
||||
if err := i.validateCronExpression(*i.CronExpression); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -59,6 +72,8 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
|
||||
return i.shouldTriggerWeekly(now, *lastBackupTime)
|
||||
case IntervalMonthly:
|
||||
return i.shouldTriggerMonthly(now, *lastBackupTime)
|
||||
case IntervalCron:
|
||||
return i.shouldTriggerCron(now, *lastBackupTime)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@@ -66,11 +81,12 @@ func (i *Interval) ShouldTriggerBackup(now time.Time, lastBackupTime *time.Time)
|
||||
|
||||
func (i *Interval) Copy() *Interval {
|
||||
return &Interval{
|
||||
ID: uuid.Nil,
|
||||
Interval: i.Interval,
|
||||
TimeOfDay: i.TimeOfDay,
|
||||
Weekday: i.Weekday,
|
||||
DayOfMonth: i.DayOfMonth,
|
||||
ID: uuid.Nil,
|
||||
Interval: i.Interval,
|
||||
TimeOfDay: i.TimeOfDay,
|
||||
Weekday: i.Weekday,
|
||||
DayOfMonth: i.DayOfMonth,
|
||||
CronExpression: i.CronExpression,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -204,3 +220,31 @@ func getStartOfWeek(t time.Time) time.Time {
|
||||
func getStartOfMonth(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
|
||||
// cron trigger: check if we've passed a scheduled cron time since last backup
|
||||
func (i *Interval) shouldTriggerCron(now, lastBackup time.Time) bool {
|
||||
if i.CronExpression == nil || *i.CronExpression == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
schedule, err := parser.Parse(*i.CronExpression)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Find the next scheduled time after the last backup
|
||||
nextAfterLastBackup := schedule.Next(lastBackup)
|
||||
|
||||
// If we're at or past that next scheduled time, trigger
|
||||
return now.After(nextAfterLastBackup) || now.Equal(nextAfterLastBackup)
|
||||
}
|
||||
|
||||
func (i *Interval) validateCronExpression(expr string) error {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
_, err := parser.Parse(expr)
|
||||
if err != nil {
|
||||
return errors.New("invalid cron expression: " + err.Error())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -457,6 +457,144 @@ func TestInterval_ShouldTriggerBackup_Monthly(t *testing.T) {
|
||||
)
|
||||
}
|
||||
|
||||
func TestInterval_ShouldTriggerBackup_Cron(t *testing.T) {
|
||||
cronExpr := "0 2 * * *" // Daily at 2:00 AM
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &cronExpr,
|
||||
}
|
||||
|
||||
t.Run("No previous backup: Trigger backup immediately", func(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
should := interval.ShouldTriggerBackup(now, nil)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("Before scheduled cron time: Do not trigger backup", func(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 1, 59, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC) // Yesterday at 2 AM
|
||||
should := interval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.False(t, should)
|
||||
})
|
||||
|
||||
t.Run("Exactly at scheduled cron time: Trigger backup", func(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 2, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC) // Yesterday at 2 AM
|
||||
should := interval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("After scheduled cron time: Trigger backup", func(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 3, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 2, 0, 0, 0, time.UTC) // Yesterday at 2 AM
|
||||
should := interval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("Backup already done after scheduled time: Do not trigger again", func(t *testing.T) {
|
||||
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 15, 2, 5, 0, 0, time.UTC) // Today at 2:05 AM
|
||||
should := interval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.False(t, should)
|
||||
})
|
||||
|
||||
t.Run("Weekly cron expression: 0 3 * * 1 (Monday at 3 AM)", func(t *testing.T) {
|
||||
weeklyCron := "0 3 * * 1" // Every Monday at 3 AM
|
||||
weeklyInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &weeklyCron,
|
||||
}
|
||||
|
||||
// Monday Jan 15, 2024 at 3:00 AM
|
||||
monday := time.Date(2024, 1, 15, 3, 0, 0, 0, time.UTC)
|
||||
// Last backup was previous Monday
|
||||
lastBackup := time.Date(2024, 1, 8, 3, 0, 0, 0, time.UTC)
|
||||
|
||||
should := weeklyInterval.ShouldTriggerBackup(monday, &lastBackup)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("Complex cron expression: 30 4 1,15 * * (1st and 15th at 4:30 AM)", func(t *testing.T) {
|
||||
complexCron := "30 4 1,15 * *" // 1st and 15th of each month at 4:30 AM
|
||||
complexInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &complexCron,
|
||||
}
|
||||
|
||||
// Jan 15, 2024 at 4:30 AM
|
||||
now := time.Date(2024, 1, 15, 4, 30, 0, 0, time.UTC)
|
||||
// Last backup was Jan 1
|
||||
lastBackup := time.Date(2024, 1, 1, 4, 30, 0, 0, time.UTC)
|
||||
|
||||
should := complexInterval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("Every 6 hours cron expression: 0 */6 * * *", func(t *testing.T) {
|
||||
sixHourlyCron := "0 */6 * * *" // Every 6 hours (0:00, 6:00, 12:00, 18:00)
|
||||
sixHourlyInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &sixHourlyCron,
|
||||
}
|
||||
|
||||
// 12:00 - next trigger after 6:00
|
||||
now := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC)
|
||||
// Last backup was at 6:00
|
||||
lastBackup := time.Date(2024, 1, 15, 6, 0, 0, 0, time.UTC)
|
||||
|
||||
should := sixHourlyInterval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.True(t, should)
|
||||
})
|
||||
|
||||
t.Run("Invalid cron expression returns false", func(t *testing.T) {
|
||||
invalidCron := "invalid cron"
|
||||
invalidInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &invalidCron,
|
||||
}
|
||||
|
||||
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||
|
||||
should := invalidInterval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.False(t, should)
|
||||
})
|
||||
|
||||
t.Run("Empty cron expression returns false", func(t *testing.T) {
|
||||
emptyCron := ""
|
||||
emptyInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &emptyCron,
|
||||
}
|
||||
|
||||
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||
|
||||
should := emptyInterval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.False(t, should)
|
||||
})
|
||||
|
||||
t.Run("Nil cron expression returns false", func(t *testing.T) {
|
||||
nilInterval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: nil,
|
||||
}
|
||||
|
||||
now := time.Date(2024, 1, 15, 10, 0, 0, 0, time.UTC)
|
||||
lastBackup := time.Date(2024, 1, 14, 10, 0, 0, 0, time.UTC)
|
||||
|
||||
should := nilInterval.ShouldTriggerBackup(now, &lastBackup)
|
||||
assert.False(t, should)
|
||||
})
|
||||
}
|
||||
|
||||
func TestInterval_Validate(t *testing.T) {
|
||||
t.Run("Daily interval requires time of day", func(t *testing.T) {
|
||||
interval := &Interval{
|
||||
@@ -526,4 +664,60 @@ func TestInterval_Validate(t *testing.T) {
|
||||
err := interval.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Cron interval requires cron expression", func(t *testing.T) {
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
}
|
||||
err := interval.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cron expression is required")
|
||||
})
|
||||
|
||||
t.Run("Cron interval with empty expression is invalid", func(t *testing.T) {
|
||||
emptyCron := ""
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &emptyCron,
|
||||
}
|
||||
err := interval.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cron expression is required")
|
||||
})
|
||||
|
||||
t.Run("Cron interval with invalid expression is invalid", func(t *testing.T) {
|
||||
invalidCron := "invalid cron"
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &invalidCron,
|
||||
}
|
||||
err := interval.Validate()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid cron expression")
|
||||
})
|
||||
|
||||
t.Run("Valid cron interval with daily expression", func(t *testing.T) {
|
||||
cronExpr := "0 2 * * *" // Daily at 2 AM
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &cronExpr,
|
||||
}
|
||||
err := interval.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("Valid cron interval with complex expression", func(t *testing.T) {
|
||||
cronExpr := "30 4 1,15 * *" // 1st and 15th of each month at 4:30 AM
|
||||
interval := &Interval{
|
||||
ID: uuid.New(),
|
||||
Interval: IntervalCron,
|
||||
CronExpression: &cronExpr,
|
||||
}
|
||||
err := interval.Validate()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -206,8 +206,8 @@ func (t *WebhookNotifier) sendPOST(webhookURL, heading, message string, logger *
|
||||
func (t *WebhookNotifier) buildRequestBody(heading, message string) []byte {
|
||||
if t.BodyTemplate != nil && *t.BodyTemplate != "" {
|
||||
result := *t.BodyTemplate
|
||||
result = strings.ReplaceAll(result, "{{heading}}", heading)
|
||||
result = strings.ReplaceAll(result, "{{message}}", message)
|
||||
result = strings.ReplaceAll(result, "{{heading}}", escapeJSONString(heading))
|
||||
result = strings.ReplaceAll(result, "{{message}}", escapeJSONString(message))
|
||||
return []byte(result)
|
||||
}
|
||||
|
||||
@@ -227,3 +227,17 @@ func (t *WebhookNotifier) applyHeaders(req *http.Request) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func escapeJSONString(s string) string {
|
||||
b, err := json.Marshal(s)
|
||||
if err != nil || len(b) < 2 {
|
||||
escaped := strings.ReplaceAll(s, `\`, `\\`)
|
||||
escaped = strings.ReplaceAll(escaped, `"`, `\"`)
|
||||
escaped = strings.ReplaceAll(escaped, "\n", `\n`)
|
||||
escaped = strings.ReplaceAll(escaped, "\r", `\r`)
|
||||
escaped = strings.ReplaceAll(escaped, "\t", `\t`)
|
||||
return escaped
|
||||
}
|
||||
|
||||
return string(b[1 : len(b)-1])
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -170,6 +171,36 @@ func Test_RestoreBackup_WhenUserIsNotWorkspaceMember_ReturnsForbidden(t *testing
|
||||
assert.Contains(t, string(testResp.Body), "insufficient permissions")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_WithIsExcludeExtensions_FlagPassedCorrectly(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("Test Workspace", owner, router)
|
||||
|
||||
_, backup := createTestDatabaseWithBackupForRestore(workspace, owner, router)
|
||||
|
||||
request := RestoreBackupRequest{
|
||||
PostgresqlDatabase: &postgresql.PostgresqlDatabase{
|
||||
Version: tools.PostgresqlVersion16,
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Username: "postgres",
|
||||
Password: "postgres",
|
||||
IsExcludeExtensions: true,
|
||||
},
|
||||
}
|
||||
|
||||
testResp := test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backup.ID.String()),
|
||||
"Bearer "+owner.Token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
|
||||
assert.Contains(t, string(testResp.Body), "restore started successfully")
|
||||
}
|
||||
|
||||
func Test_RestoreBackup_AuditLogWritten(t *testing.T) {
|
||||
router := createTestRouter()
|
||||
owner := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
@@ -340,7 +371,7 @@ func createTestBackup(
|
||||
dummyContent := []byte("dummy backup content for testing")
|
||||
reader := strings.NewReader(string(dummyContent))
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
if err := storages[0].SaveFile(fieldEncryptor, logger, backup.ID, reader); err != nil {
|
||||
if err := storages[0].SaveFile(context.Background(), fieldEncryptor, logger, backup.ID, reader); err != nil {
|
||||
panic(fmt.Sprintf("Failed to create test backup file: %v", err))
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"postgresus-backend/internal/features/restores/usecases"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
@@ -22,6 +23,7 @@ var restoreService = &RestoreService{
|
||||
logger.GetLogger(),
|
||||
workspaces_services.GetWorkspaceService(),
|
||||
audit_logs.GetAuditLogService(),
|
||||
encryption.GetFieldEncryptor(),
|
||||
}
|
||||
var restoreController = &RestoreController{
|
||||
restoreService,
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
package restores
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/databases/databases/mysql"
|
||||
"postgresus-backend/internal/features/databases/databases/postgresql"
|
||||
)
|
||||
|
||||
type RestoreBackupRequest struct {
|
||||
PostgresqlDatabase *postgresql.PostgresqlDatabase `json:"postgresqlDatabase"`
|
||||
MysqlDatabase *mysql.MysqlDatabase `json:"mysqlDatabase"`
|
||||
MariadbDatabase *mariadb.MariadbDatabase `json:"mariadbDatabase"`
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_models "postgresus-backend/internal/features/users/models"
|
||||
workspaces_services "postgresus-backend/internal/features/workspaces/services"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
"time"
|
||||
|
||||
@@ -30,6 +31,7 @@ type RestoreService struct {
|
||||
logger *slog.Logger
|
||||
workspaceService *workspaces_services.WorkspaceService
|
||||
auditLogService *audit_logs.AuditLogService
|
||||
fieldEncryptor encryption.FieldEncryptor
|
||||
}
|
||||
|
||||
func (s *RestoreService) OnBeforeBackupRemove(backup *backups.Backup) error {
|
||||
@@ -120,19 +122,8 @@ func (s *RestoreService) RestoreBackupWithAuth(
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf(
|
||||
"restore from %s to %s\n",
|
||||
backupDatabase.Postgresql.Version,
|
||||
requestDTO.PostgresqlDatabase.Version,
|
||||
)
|
||||
|
||||
if tools.IsBackupDbVersionHigherThanRestoreDbVersion(
|
||||
backupDatabase.Postgresql.Version,
|
||||
requestDTO.PostgresqlDatabase.Version,
|
||||
) {
|
||||
return errors.New(`backup database version is higher than restore database version. ` +
|
||||
`Should be restored to the same version as the backup database or higher. ` +
|
||||
`For example, you can restore PG 15 backup to PG 15, 16 or higher. But cannot restore to 14 and lower`)
|
||||
if err := s.validateVersionCompatibility(backupDatabase, requestDTO); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -167,10 +158,19 @@ func (s *RestoreService) RestoreBackup(
|
||||
return err
|
||||
}
|
||||
|
||||
if database.Type == databases.DatabaseTypePostgres {
|
||||
switch database.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
if requestDTO.PostgresqlDatabase == nil {
|
||||
return errors.New("postgresql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMysql:
|
||||
if requestDTO.MysqlDatabase == nil {
|
||||
return errors.New("mysql database is required")
|
||||
}
|
||||
case databases.DatabaseTypeMariadb:
|
||||
if requestDTO.MariadbDatabase == nil {
|
||||
return errors.New("mariadb database is required")
|
||||
}
|
||||
}
|
||||
|
||||
restore := models.Restore{
|
||||
@@ -211,7 +211,19 @@ func (s *RestoreService) RestoreBackup(
|
||||
start := time.Now().UTC()
|
||||
|
||||
restoringToDB := &databases.Database{
|
||||
Type: database.Type,
|
||||
Postgresql: requestDTO.PostgresqlDatabase,
|
||||
Mysql: requestDTO.MysqlDatabase,
|
||||
Mariadb: requestDTO.MariadbDatabase,
|
||||
}
|
||||
|
||||
if err := restoringToDB.PopulateVersionIfEmpty(s.logger, s.fieldEncryptor); err != nil {
|
||||
return fmt.Errorf("failed to auto-detect database version: %w", err)
|
||||
}
|
||||
|
||||
isExcludeExtensions := false
|
||||
if requestDTO.PostgresqlDatabase != nil {
|
||||
isExcludeExtensions = requestDTO.PostgresqlDatabase.IsExcludeExtensions
|
||||
}
|
||||
|
||||
err = s.restoreBackupUsecase.Execute(
|
||||
@@ -221,6 +233,7 @@ func (s *RestoreService) RestoreBackup(
|
||||
restoringToDB,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
if err != nil {
|
||||
errMsg := err.Error()
|
||||
@@ -244,3 +257,80 @@ func (s *RestoreService) RestoreBackup(
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RestoreService) validateVersionCompatibility(
|
||||
backupDatabase *databases.Database,
|
||||
requestDTO RestoreBackupRequest,
|
||||
) error {
|
||||
// populate version
|
||||
if requestDTO.MariadbDatabase != nil {
|
||||
err := requestDTO.MariadbDatabase.PopulateVersion(
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
backupDatabase.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if requestDTO.MysqlDatabase != nil {
|
||||
err := requestDTO.MysqlDatabase.PopulateVersion(
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
backupDatabase.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if requestDTO.PostgresqlDatabase != nil {
|
||||
err := requestDTO.PostgresqlDatabase.PopulateVersion(
|
||||
s.logger,
|
||||
s.fieldEncryptor,
|
||||
backupDatabase.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch backupDatabase.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
if requestDTO.PostgresqlDatabase == nil {
|
||||
return errors.New("postgresql database configuration is required for restore")
|
||||
}
|
||||
if tools.IsBackupDbVersionHigherThanRestoreDbVersion(
|
||||
backupDatabase.Postgresql.Version,
|
||||
requestDTO.PostgresqlDatabase.Version,
|
||||
) {
|
||||
return errors.New(`backup database version is higher than restore database version. ` +
|
||||
`Should be restored to the same version as the backup database or higher. ` +
|
||||
`For example, you can restore PG 15 backup to PG 15, 16 or higher. But cannot restore to 14 and lower`)
|
||||
}
|
||||
case databases.DatabaseTypeMysql:
|
||||
if requestDTO.MysqlDatabase == nil {
|
||||
return errors.New("mysql database configuration is required for restore")
|
||||
}
|
||||
if tools.IsMysqlBackupVersionHigherThanRestoreVersion(
|
||||
backupDatabase.Mysql.Version,
|
||||
requestDTO.MysqlDatabase.Version,
|
||||
) {
|
||||
return errors.New(`backup database version is higher than restore database version. ` +
|
||||
`Should be restored to the same version as the backup database or higher. ` +
|
||||
`For example, you can restore MySQL 8.0 backup to MySQL 8.0, 8.4 or higher. But cannot restore to 5.7`)
|
||||
}
|
||||
case databases.DatabaseTypeMariadb:
|
||||
if requestDTO.MariadbDatabase == nil {
|
||||
return errors.New("mariadb database configuration is required for restore")
|
||||
}
|
||||
if tools.IsMariadbBackupVersionHigherThanRestoreVersion(
|
||||
backupDatabase.Mariadb.Version,
|
||||
requestDTO.MariadbDatabase.Version,
|
||||
) {
|
||||
return errors.New(`backup database version is higher than restore database version. ` +
|
||||
`Should be restored to the same version as the backup database or higher. ` +
|
||||
`For example, you can restore MariaDB 10.11 backup to MariaDB 10.11, 11.4 or higher. But cannot restore to 10.6`)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package usecases
|
||||
|
||||
import (
|
||||
usecases_mariadb "postgresus-backend/internal/features/restores/usecases/mariadb"
|
||||
usecases_mysql "postgresus-backend/internal/features/restores/usecases/mysql"
|
||||
usecases_postgresql "postgresus-backend/internal/features/restores/usecases/postgresql"
|
||||
)
|
||||
|
||||
var restoreBackupUsecase = &RestoreBackupUsecase{
|
||||
usecases_postgresql.GetRestorePostgresqlBackupUsecase(),
|
||||
usecases_mysql.GetRestoreMysqlBackupUsecase(),
|
||||
usecases_mariadb.GetRestoreMariadbBackupUsecase(),
|
||||
}
|
||||
|
||||
func GetRestoreBackupUsecase() *RestoreBackupUsecase {
|
||||
|
||||
15
backend/internal/features/restores/usecases/mariadb/di.go
Normal file
15
backend/internal/features/restores/usecases/mariadb/di.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package usecases_mariadb
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreMariadbBackupUsecase = &RestoreMariadbBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
secrets.GetSecretKeyService(),
|
||||
}
|
||||
|
||||
func GetRestoreMariadbBackupUsecase() *RestoreMariadbBackupUsecase {
|
||||
return restoreMariadbBackupUsecase
|
||||
}
|
||||
@@ -0,0 +1,473 @@
|
||||
package usecases_mariadb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mariadbtypes "postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
type RestoreMariadbBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) Execute(
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMariadb {
|
||||
return errors.New("database type not supported")
|
||||
}
|
||||
|
||||
uc.logger.Info(
|
||||
"Restoring MariaDB backup via mariadb client",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
)
|
||||
|
||||
mdb := restoringToDB.Mariadb
|
||||
if mdb == nil {
|
||||
return fmt.Errorf("mariadb configuration is required for restore")
|
||||
}
|
||||
|
||||
if mdb.Database == nil || *mdb.Database == "" {
|
||||
return fmt.Errorf("target database name is required for mariadb restore")
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"--host=" + mdb.Host,
|
||||
"--port=" + strconv.Itoa(mdb.Port),
|
||||
"--user=" + mdb.Username,
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if mdb.IsHttps {
|
||||
args = append(args, "--ssl")
|
||||
}
|
||||
|
||||
if mdb.Database != nil && *mdb.Database != "" {
|
||||
args = append(args, *mdb.Database)
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
originalDB,
|
||||
tools.GetMariadbExecutable(
|
||||
tools.MariadbExecutableMariadb,
|
||||
mdb.Version,
|
||||
config.GetEnv().EnvMode,
|
||||
config.GetEnv().MariadbInstallDir,
|
||||
),
|
||||
args,
|
||||
mdb.Password,
|
||||
backup,
|
||||
storage,
|
||||
mdb,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) restoreFromStorage(
|
||||
database *databases.Database,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
decryptedPassword, err := fieldEncryptor.Decrypt(database.ID, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile, err := uc.createTempMyCnfFile(mdbConfig, decryptedPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create .my.cnf: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download backup: %w", err)
|
||||
}
|
||||
defer cleanupFunc()
|
||||
|
||||
return uc.executeMariadbRestore(
|
||||
ctx,
|
||||
database,
|
||||
mariadbBin,
|
||||
args,
|
||||
myCnfFile,
|
||||
tempBackupFile,
|
||||
backup,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) executeMariadbRestore(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
mariadbBin string,
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupFile string,
|
||||
backup *backups.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
cmd := exec.CommandContext(ctx, mariadbBin, fullArgs...)
|
||||
uc.logger.Info("Executing MariaDB restore command", "command", cmd.String())
|
||||
|
||||
backupFileHandle, err := os.Open(backupFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open backup file: %w", err)
|
||||
}
|
||||
defer func() { _ = backupFileHandle.Close() }()
|
||||
|
||||
var inputReader io.Reader = backupFileHandle
|
||||
|
||||
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
|
||||
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup decryption: %w", err)
|
||||
}
|
||||
inputReader = decryptReader
|
||||
}
|
||||
|
||||
zstdReader, err := zstd.NewReader(inputReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create zstd reader: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
cmd.Stdin = zstdReader
|
||||
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env,
|
||||
"MYSQL_PWD=",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
)
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrCh := make(chan []byte, 1)
|
||||
go func() {
|
||||
output, _ := io.ReadAll(stderrPipe)
|
||||
stderrCh <- output
|
||||
}()
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start mariadb: %w", err)
|
||||
}
|
||||
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
return uc.handleMariadbRestoreError(database, waitErr, stderrOutput, mariadbBin)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) downloadBackupToTempFile(
|
||||
ctx context.Context,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) (string, func(), error) {
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
cleanupFunc := func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
|
||||
|
||||
uc.logger.Info(
|
||||
"Downloading backup file from storage to temporary file",
|
||||
"backupId", backup.ID,
|
||||
"tempFile", tempBackupFile,
|
||||
"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
|
||||
)
|
||||
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rawReader.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close backup reader", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tempFile, err := os.Create(tempBackupFile)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := tempFile.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close temporary file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
|
||||
}
|
||||
|
||||
uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
|
||||
return tempBackupFile, cleanupFunc, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get master key for decryption: %w", err)
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encryption salt: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encryption IV: %w", err)
|
||||
}
|
||||
|
||||
decryptReader, err := encryption.NewDecryptionReader(
|
||||
reader,
|
||||
masterKey,
|
||||
backup.ID,
|
||||
salt,
|
||||
iv,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create decryption reader: %w", err)
|
||||
}
|
||||
|
||||
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
|
||||
return decryptReader, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) createTempMyCnfFile(
|
||||
mdbConfig *mariadbtypes.MariadbDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp("", "mycnf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
user=%s
|
||||
password="%s"
|
||||
host=%s
|
||||
port=%d
|
||||
`, mdbConfig.Username, tools.EscapeMariadbPassword(password), mdbConfig.Host, mdbConfig.Port)
|
||||
|
||||
if mdbConfig.IsHttps {
|
||||
content += "ssl=true\n"
|
||||
} else {
|
||||
content += "ssl=false\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
return myCnfFile, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
) (int64, error) {
|
||||
buf := make([]byte, 16*1024*1024)
|
||||
var totalBytesWritten int64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
|
||||
}
|
||||
|
||||
bytesRead, readErr := src.Read(buf)
|
||||
if bytesRead > 0 {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
if bytesWritten < 0 || bytesRead < bytesWritten {
|
||||
bytesWritten = 0
|
||||
if writeErr == nil {
|
||||
writeErr = fmt.Errorf("invalid write result")
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return totalBytesWritten, writeErr
|
||||
}
|
||||
|
||||
if bytesRead != bytesWritten {
|
||||
return totalBytesWritten, io.ErrShortWrite
|
||||
}
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMariadbBackupUsecase) handleMariadbRestoreError(
|
||||
database *databases.Database,
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
mariadbBin string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(mariadbBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
if containsIgnoreCase(stderrStr, "access denied") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB access denied. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "can't connect") ||
|
||||
containsIgnoreCase(stderrStr, "connection refused") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB connection refused. Check if the server is running and accessible. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "unknown database") {
|
||||
backupDbName := "unknown"
|
||||
if database.Mariadb != nil && database.Mariadb.Database != nil {
|
||||
backupDbName = *database.Mariadb.Database
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"target database does not exist (backup db %s). Create the database before restoring. stderr: %s",
|
||||
backupDbName,
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB SSL connection failed. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"MariaDB connection timeout. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
15
backend/internal/features/restores/usecases/mysql/di.go
Normal file
15
backend/internal/features/restores/usecases/mysql/di.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package usecases_mysql
|
||||
|
||||
import (
|
||||
"postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
)
|
||||
|
||||
var restoreMysqlBackupUsecase = &RestoreMysqlBackupUsecase{
|
||||
logger.GetLogger(),
|
||||
secrets.GetSecretKeyService(),
|
||||
}
|
||||
|
||||
func GetRestoreMysqlBackupUsecase() *RestoreMysqlBackupUsecase {
|
||||
return restoreMysqlBackupUsecase
|
||||
}
|
||||
@@ -0,0 +1,463 @@
|
||||
package usecases_mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/klauspost/compress/zstd"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
"postgresus-backend/internal/features/backups/backups/encryption"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mysqltypes "postgresus-backend/internal/features/databases/databases/mysql"
|
||||
encryption_secrets "postgresus-backend/internal/features/encryption/secrets"
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
util_encryption "postgresus-backend/internal/util/encryption"
|
||||
files_utils "postgresus-backend/internal/util/files"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
type RestoreMysqlBackupUsecase struct {
|
||||
logger *slog.Logger
|
||||
secretKeyService *encryption_secrets.SecretKeyService
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) Execute(
|
||||
originalDB *databases.Database,
|
||||
restoringToDB *databases.Database,
|
||||
backupConfig *backups_config.BackupConfig,
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypeMysql {
|
||||
return errors.New("database type not supported")
|
||||
}
|
||||
|
||||
uc.logger.Info(
|
||||
"Restoring MySQL backup via mysql client",
|
||||
"restoreId", restore.ID,
|
||||
"backupId", backup.ID,
|
||||
)
|
||||
|
||||
my := restoringToDB.Mysql
|
||||
if my == nil {
|
||||
return fmt.Errorf("mysql configuration is required for restore")
|
||||
}
|
||||
|
||||
if my.Database == nil || *my.Database == "" {
|
||||
return fmt.Errorf("target database name is required for mysql restore")
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"--host=" + my.Host,
|
||||
"--port=" + strconv.Itoa(my.Port),
|
||||
"--user=" + my.Username,
|
||||
"--verbose",
|
||||
}
|
||||
|
||||
if my.IsHttps {
|
||||
args = append(args, "--ssl-mode=REQUIRED")
|
||||
}
|
||||
|
||||
if my.Database != nil && *my.Database != "" {
|
||||
args = append(args, *my.Database)
|
||||
}
|
||||
|
||||
return uc.restoreFromStorage(
|
||||
originalDB,
|
||||
tools.GetMysqlExecutable(
|
||||
my.Version,
|
||||
tools.MysqlExecutableMysql,
|
||||
config.GetEnv().EnvMode,
|
||||
config.GetEnv().MysqlInstallDir,
|
||||
),
|
||||
args,
|
||||
my.Password,
|
||||
backup,
|
||||
storage,
|
||||
my,
|
||||
)
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) restoreFromStorage(
|
||||
database *databases.Database,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
password string,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(1 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if config.IsShouldShutdown() {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
decryptedPassword, err := fieldEncryptor.Decrypt(database.ID, password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt password: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile, err := uc.createTempMyCnfFile(myConfig, decryptedPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create .my.cnf: %w", err)
|
||||
}
|
||||
defer func() { _ = os.RemoveAll(filepath.Dir(myCnfFile)) }()
|
||||
|
||||
tempBackupFile, cleanupFunc, err := uc.downloadBackupToTempFile(ctx, backup, storage)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download backup: %w", err)
|
||||
}
|
||||
defer cleanupFunc()
|
||||
|
||||
return uc.executeMysqlRestore(ctx, database, mysqlBin, args, myCnfFile, tempBackupFile, backup)
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) executeMysqlRestore(
|
||||
ctx context.Context,
|
||||
database *databases.Database,
|
||||
mysqlBin string,
|
||||
args []string,
|
||||
myCnfFile string,
|
||||
backupFile string,
|
||||
backup *backups.Backup,
|
||||
) error {
|
||||
fullArgs := append([]string{"--defaults-file=" + myCnfFile}, args...)
|
||||
|
||||
cmd := exec.CommandContext(ctx, mysqlBin, fullArgs...)
|
||||
uc.logger.Info("Executing MySQL restore command", "command", cmd.String())
|
||||
|
||||
backupFileHandle, err := os.Open(backupFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to open backup file: %w", err)
|
||||
}
|
||||
defer func() { _ = backupFileHandle.Close() }()
|
||||
|
||||
var inputReader io.Reader = backupFileHandle
|
||||
|
||||
if backup.Encryption == backups_config.BackupEncryptionEncrypted {
|
||||
decryptReader, err := uc.setupDecryption(backupFileHandle, backup)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup decryption: %w", err)
|
||||
}
|
||||
inputReader = decryptReader
|
||||
}
|
||||
|
||||
zstdReader, err := zstd.NewReader(inputReader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create zstd reader: %w", err)
|
||||
}
|
||||
defer zstdReader.Close()
|
||||
|
||||
cmd.Stdin = zstdReader
|
||||
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Env = append(cmd.Env,
|
||||
"MYSQL_PWD=",
|
||||
"LC_ALL=C.UTF-8",
|
||||
"LANG=C.UTF-8",
|
||||
)
|
||||
|
||||
stderrPipe, err := cmd.StderrPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
stderrCh := make(chan []byte, 1)
|
||||
go func() {
|
||||
output, _ := io.ReadAll(stderrPipe)
|
||||
stderrCh <- output
|
||||
}()
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start mysql: %w", err)
|
||||
}
|
||||
|
||||
waitErr := cmd.Wait()
|
||||
stderrOutput := <-stderrCh
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return fmt.Errorf("restore cancelled due to shutdown")
|
||||
}
|
||||
|
||||
if waitErr != nil {
|
||||
return uc.handleMysqlRestoreError(database, waitErr, stderrOutput, mysqlBin)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) downloadBackupToTempFile(
|
||||
ctx context.Context,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
) (string, func(), error) {
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
config.GetEnv().TempFolder,
|
||||
})
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to ensure directories: %w", err)
|
||||
}
|
||||
|
||||
tempDir, err := os.MkdirTemp(config.GetEnv().TempFolder, "restore_"+uuid.New().String())
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("failed to create temporary directory: %w", err)
|
||||
}
|
||||
|
||||
cleanupFunc := func() {
|
||||
_ = os.RemoveAll(tempDir)
|
||||
}
|
||||
|
||||
tempBackupFile := filepath.Join(tempDir, "backup.sql.zst")
|
||||
|
||||
uc.logger.Info(
|
||||
"Downloading backup file from storage to temporary file",
|
||||
"backupId", backup.ID,
|
||||
"tempFile", tempBackupFile,
|
||||
"encrypted", backup.Encryption == backups_config.BackupEncryptionEncrypted,
|
||||
)
|
||||
|
||||
fieldEncryptor := util_encryption.GetFieldEncryptor()
|
||||
rawReader, err := storage.GetFile(fieldEncryptor, backup.ID)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to get backup file from storage: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := rawReader.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close backup reader", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tempFile, err := os.Create(tempBackupFile)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to create temporary backup file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := tempFile.Close(); err != nil {
|
||||
uc.logger.Error("Failed to close temporary file", "error", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = uc.copyWithShutdownCheck(ctx, tempFile, rawReader)
|
||||
if err != nil {
|
||||
cleanupFunc()
|
||||
return "", nil, fmt.Errorf("failed to write backup to temporary file: %w", err)
|
||||
}
|
||||
|
||||
uc.logger.Info("Backup file written to temporary location", "tempFile", tempBackupFile)
|
||||
return tempBackupFile, cleanupFunc, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) setupDecryption(
|
||||
reader io.Reader,
|
||||
backup *backups.Backup,
|
||||
) (io.Reader, error) {
|
||||
if backup.EncryptionSalt == nil || backup.EncryptionIV == nil {
|
||||
return nil, fmt.Errorf("backup is encrypted but missing encryption metadata")
|
||||
}
|
||||
|
||||
masterKey, err := uc.secretKeyService.GetSecretKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get master key for decryption: %w", err)
|
||||
}
|
||||
|
||||
salt, err := base64.StdEncoding.DecodeString(*backup.EncryptionSalt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encryption salt: %w", err)
|
||||
}
|
||||
|
||||
iv, err := base64.StdEncoding.DecodeString(*backup.EncryptionIV)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decode encryption IV: %w", err)
|
||||
}
|
||||
|
||||
decryptReader, err := encryption.NewDecryptionReader(
|
||||
reader,
|
||||
masterKey,
|
||||
backup.ID,
|
||||
salt,
|
||||
iv,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create decryption reader: %w", err)
|
||||
}
|
||||
|
||||
uc.logger.Info("Using decryption for encrypted backup", "backupId", backup.ID)
|
||||
return decryptReader, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) createTempMyCnfFile(
|
||||
myConfig *mysqltypes.MysqlDatabase,
|
||||
password string,
|
||||
) (string, error) {
|
||||
tempDir, err := os.MkdirTemp("", "mycnf")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
||||
}
|
||||
|
||||
myCnfFile := filepath.Join(tempDir, ".my.cnf")
|
||||
|
||||
content := fmt.Sprintf(`[client]
|
||||
user=%s
|
||||
password="%s"
|
||||
host=%s
|
||||
port=%d
|
||||
`, myConfig.Username, tools.EscapeMysqlPassword(password), myConfig.Host, myConfig.Port)
|
||||
|
||||
if myConfig.IsHttps {
|
||||
content += "ssl-mode=REQUIRED\n"
|
||||
}
|
||||
|
||||
err = os.WriteFile(myCnfFile, []byte(content), 0600)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write .my.cnf: %w", err)
|
||||
}
|
||||
|
||||
return myCnfFile, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) copyWithShutdownCheck(
|
||||
ctx context.Context,
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
) (int64, error) {
|
||||
buf := make([]byte, 16*1024*1024)
|
||||
var totalBytesWritten int64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
if config.IsShouldShutdown() {
|
||||
return totalBytesWritten, fmt.Errorf("copy cancelled due to shutdown")
|
||||
}
|
||||
|
||||
bytesRead, readErr := src.Read(buf)
|
||||
if bytesRead > 0 {
|
||||
bytesWritten, writeErr := dst.Write(buf[0:bytesRead])
|
||||
if bytesWritten < 0 || bytesRead < bytesWritten {
|
||||
bytesWritten = 0
|
||||
if writeErr == nil {
|
||||
writeErr = fmt.Errorf("invalid write result")
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return totalBytesWritten, writeErr
|
||||
}
|
||||
|
||||
if bytesRead != bytesWritten {
|
||||
return totalBytesWritten, io.ErrShortWrite
|
||||
}
|
||||
|
||||
totalBytesWritten += int64(bytesWritten)
|
||||
}
|
||||
|
||||
if readErr != nil {
|
||||
if readErr != io.EOF {
|
||||
return totalBytesWritten, readErr
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return totalBytesWritten, nil
|
||||
}
|
||||
|
||||
func (uc *RestoreMysqlBackupUsecase) handleMysqlRestoreError(
|
||||
database *databases.Database,
|
||||
waitErr error,
|
||||
stderrOutput []byte,
|
||||
mysqlBin string,
|
||||
) error {
|
||||
stderrStr := string(stderrOutput)
|
||||
errorMsg := fmt.Sprintf(
|
||||
"%s failed: %v – stderr: %s",
|
||||
filepath.Base(mysqlBin),
|
||||
waitErr,
|
||||
stderrStr,
|
||||
)
|
||||
|
||||
if containsIgnoreCase(stderrStr, "access denied") {
|
||||
return fmt.Errorf(
|
||||
"MySQL access denied. Check username and password. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "can't connect") ||
|
||||
containsIgnoreCase(stderrStr, "connection refused") {
|
||||
return fmt.Errorf(
|
||||
"MySQL connection refused. Check if the server is running and accessible. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "unknown database") {
|
||||
backupDbName := "unknown"
|
||||
if database.Mysql != nil && database.Mysql.Database != nil {
|
||||
backupDbName = *database.Mysql.Database
|
||||
}
|
||||
|
||||
return fmt.Errorf(
|
||||
"target database does not exist (backup db %s). Create the database before restoring. stderr: %s",
|
||||
backupDbName,
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "ssl") {
|
||||
return fmt.Errorf(
|
||||
"MySQL SSL connection failed. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
if containsIgnoreCase(stderrStr, "timeout") {
|
||||
return fmt.Errorf(
|
||||
"MySQL connection timeout. stderr: %s",
|
||||
stderrStr,
|
||||
)
|
||||
}
|
||||
|
||||
return errors.New(errorMsg)
|
||||
}
|
||||
|
||||
func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
@@ -42,6 +42,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
restore models.Restore,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
if originalDB.Type != databases.DatabaseTypePostgres {
|
||||
return errors.New("database type not supported")
|
||||
@@ -96,6 +97,7 @@ func (uc *RestorePostgresqlBackupUsecase) Execute(
|
||||
backup,
|
||||
storage,
|
||||
pg,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -108,6 +110,7 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
uc.logger.Info(
|
||||
"Restoring PostgreSQL backup from storage via temporary file",
|
||||
@@ -115,6 +118,8 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
pgBin,
|
||||
"args",
|
||||
args,
|
||||
"isExcludeExtensions",
|
||||
isExcludeExtensions,
|
||||
)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
|
||||
@@ -171,6 +176,26 @@ func (uc *RestorePostgresqlBackupUsecase) restoreFromStorage(
|
||||
}
|
||||
defer cleanupFunc()
|
||||
|
||||
// If excluding extensions, generate filtered TOC list and use it
|
||||
if isExcludeExtensions {
|
||||
tocListFile, err := uc.generateFilteredTocList(
|
||||
ctx,
|
||||
pgBin,
|
||||
tempBackupFile,
|
||||
pgpassFile,
|
||||
pgConfig,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate filtered TOC list: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = os.Remove(tocListFile)
|
||||
}()
|
||||
|
||||
// Add -L flag to use the filtered list
|
||||
args = append(args, "-L", tocListFile)
|
||||
}
|
||||
|
||||
// Add the temporary backup file as the last argument to pg_restore
|
||||
args = append(args, tempBackupFile)
|
||||
|
||||
@@ -502,7 +527,7 @@ func (uc *RestorePostgresqlBackupUsecase) copyWithShutdownCheck(
|
||||
dst io.Writer,
|
||||
src io.Reader,
|
||||
) (int64, error) {
|
||||
buf := make([]byte, 32*1024) // 32KB buffer
|
||||
buf := make([]byte, 16*1024*1024) // 16MB buffer
|
||||
var totalBytesWritten int64
|
||||
|
||||
for {
|
||||
@@ -554,6 +579,75 @@ func containsIgnoreCase(str, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(str), strings.ToLower(substr))
|
||||
}
|
||||
|
||||
// generateFilteredTocList generates a pg_restore TOC list file with extensions filtered out.
|
||||
// This is used when isExcludeExtensions is true to skip CREATE EXTENSION statements.
|
||||
func (uc *RestorePostgresqlBackupUsecase) generateFilteredTocList(
|
||||
ctx context.Context,
|
||||
pgBin string,
|
||||
backupFile string,
|
||||
pgpassFile string,
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
) (string, error) {
|
||||
uc.logger.Info("Generating filtered TOC list to exclude extensions", "backupFile", backupFile)
|
||||
|
||||
// Run pg_restore -l to get the TOC list
|
||||
listCmd := exec.CommandContext(ctx, pgBin, "-l", backupFile)
|
||||
uc.setupPgRestoreEnvironment(listCmd, pgpassFile, pgConfig)
|
||||
|
||||
tocOutput, err := listCmd.Output()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate TOC list: %w", err)
|
||||
}
|
||||
|
||||
// Filter out EXTENSION-related lines (both CREATE EXTENSION and COMMENT ON EXTENSION)
|
||||
var filteredLines []string
|
||||
for line := range strings.SplitSeq(string(tocOutput), "\n") {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if trimmedLine == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
upperLine := strings.ToUpper(trimmedLine)
|
||||
|
||||
// Skip lines that contain " EXTENSION " - this catches both:
|
||||
// - CREATE EXTENSION entries: "3420; 0 0 EXTENSION - uuid-ossp"
|
||||
// - COMMENT ON EXTENSION entries: "3462; 0 0 COMMENT - EXTENSION "uuid-ossp""
|
||||
if strings.Contains(upperLine, " EXTENSION ") {
|
||||
uc.logger.Info("Excluding extension-related entry from restore", "tocLine", trimmedLine)
|
||||
continue
|
||||
}
|
||||
|
||||
filteredLines = append(filteredLines, line)
|
||||
}
|
||||
|
||||
// Write filtered TOC to temporary file
|
||||
tocFile, err := os.CreateTemp("", "pg_restore_toc_*.list")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create TOC list file: %w", err)
|
||||
}
|
||||
tocFilePath := tocFile.Name()
|
||||
|
||||
filteredContent := strings.Join(filteredLines, "\n")
|
||||
if _, err := tocFile.WriteString(filteredContent); err != nil {
|
||||
_ = tocFile.Close()
|
||||
_ = os.Remove(tocFilePath)
|
||||
return "", fmt.Errorf("failed to write TOC list file: %w", err)
|
||||
}
|
||||
|
||||
if err := tocFile.Close(); err != nil {
|
||||
_ = os.Remove(tocFilePath)
|
||||
return "", fmt.Errorf("failed to close TOC list file: %w", err)
|
||||
}
|
||||
|
||||
uc.logger.Info("Generated filtered TOC list file",
|
||||
"tocFile", tocFilePath,
|
||||
"originalLines", len(strings.Split(string(tocOutput), "\n")),
|
||||
"filteredLines", len(filteredLines),
|
||||
)
|
||||
|
||||
return tocFilePath, nil
|
||||
}
|
||||
|
||||
// createTempPgpassFile creates a temporary .pgpass file with the given password
|
||||
func (uc *RestorePostgresqlBackupUsecase) createTempPgpassFile(
|
||||
pgConfig *pgtypes.PostgresqlDatabase,
|
||||
|
||||
@@ -2,16 +2,21 @@ package usecases
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
"postgresus-backend/internal/features/restores/models"
|
||||
usecases_mariadb "postgresus-backend/internal/features/restores/usecases/mariadb"
|
||||
usecases_mysql "postgresus-backend/internal/features/restores/usecases/mysql"
|
||||
usecases_postgresql "postgresus-backend/internal/features/restores/usecases/postgresql"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
)
|
||||
|
||||
type RestoreBackupUsecase struct {
|
||||
restorePostgresqlBackupUsecase *usecases_postgresql.RestorePostgresqlBackupUsecase
|
||||
restoreMysqlBackupUsecase *usecases_mysql.RestoreMysqlBackupUsecase
|
||||
restoreMariadbBackupUsecase *usecases_mariadb.RestoreMariadbBackupUsecase
|
||||
}
|
||||
|
||||
func (uc *RestoreBackupUsecase) Execute(
|
||||
@@ -21,8 +26,10 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
restoringToDB *databases.Database,
|
||||
backup *backups.Backup,
|
||||
storage *storages.Storage,
|
||||
isExcludeExtensions bool,
|
||||
) error {
|
||||
if originalDB.Type == databases.DatabaseTypePostgres {
|
||||
switch originalDB.Type {
|
||||
case databases.DatabaseTypePostgres:
|
||||
return uc.restorePostgresqlBackupUsecase.Execute(
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
@@ -30,8 +37,27 @@ func (uc *RestoreBackupUsecase) Execute(
|
||||
restore,
|
||||
backup,
|
||||
storage,
|
||||
isExcludeExtensions,
|
||||
)
|
||||
case databases.DatabaseTypeMysql:
|
||||
return uc.restoreMysqlBackupUsecase.Execute(
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
restore,
|
||||
backup,
|
||||
storage,
|
||||
)
|
||||
case databases.DatabaseTypeMariadb:
|
||||
return uc.restoreMariadbBackupUsecase.Execute(
|
||||
originalDB,
|
||||
restoringToDB,
|
||||
backupConfig,
|
||||
restore,
|
||||
backup,
|
||||
storage,
|
||||
)
|
||||
default:
|
||||
return errors.New("database type not supported")
|
||||
}
|
||||
|
||||
return errors.New("database type not supported")
|
||||
}
|
||||
|
||||
@@ -8,10 +8,13 @@ import (
|
||||
|
||||
audit_logs "postgresus-backend/internal/features/audit_logs"
|
||||
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
|
||||
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
|
||||
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_middleware "postgresus-backend/internal/features/users/middleware"
|
||||
users_services "postgresus-backend/internal/features/users/services"
|
||||
@@ -738,6 +741,155 @@ func Test_StorageSensitiveDataLifecycle_AllTypes(t *testing.T) {
|
||||
assert.Equal(t, "", storage.GoogleDriveStorage.TokenJSON)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FTP Storage",
|
||||
storageType: StorageTypeFTP,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeFTP,
|
||||
Name: "Test FTP Storage",
|
||||
FTPStorage: &ftp_storage.FTPStorage{
|
||||
Host: "ftp.example.com",
|
||||
Port: 21,
|
||||
Username: "testuser",
|
||||
Password: "original-password",
|
||||
UseSSL: false,
|
||||
Path: "/backups",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeFTP,
|
||||
Name: "Updated FTP Storage",
|
||||
FTPStorage: &ftp_storage.FTPStorage{
|
||||
Host: "ftp2.example.com",
|
||||
Port: 2121,
|
||||
Username: "testuser2",
|
||||
Password: "",
|
||||
UseSSL: true,
|
||||
Path: "/backups2",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.FTPStorage.Password, "enc:"),
|
||||
"Password should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
password, err := encryptor.Decrypt(storage.ID, storage.FTPStorage.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password", password)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.FTPStorage.Password)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SFTP Storage",
|
||||
storageType: StorageTypeSFTP,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeSFTP,
|
||||
Name: "Test SFTP Storage",
|
||||
SFTPStorage: &sftp_storage.SFTPStorage{
|
||||
Host: "sftp.example.com",
|
||||
Port: 22,
|
||||
Username: "testuser",
|
||||
Password: "original-password",
|
||||
PrivateKey: "original-private-key",
|
||||
SkipHostKeyVerify: false,
|
||||
Path: "/backups",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeSFTP,
|
||||
Name: "Updated SFTP Storage",
|
||||
SFTPStorage: &sftp_storage.SFTPStorage{
|
||||
Host: "sftp2.example.com",
|
||||
Port: 2222,
|
||||
Username: "testuser2",
|
||||
Password: "",
|
||||
PrivateKey: "",
|
||||
SkipHostKeyVerify: true,
|
||||
Path: "/backups2",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.SFTPStorage.Password, "enc:"),
|
||||
"Password should be encrypted with 'enc:' prefix")
|
||||
assert.True(t, strings.HasPrefix(storage.SFTPStorage.PrivateKey, "enc:"),
|
||||
"PrivateKey should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
password, err := encryptor.Decrypt(storage.ID, storage.SFTPStorage.Password)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-password", password)
|
||||
|
||||
privateKey, err := encryptor.Decrypt(storage.ID, storage.SFTPStorage.PrivateKey)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "original-private-key", privateKey)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.SFTPStorage.Password)
|
||||
assert.Equal(t, "", storage.SFTPStorage.PrivateKey)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Rclone Storage",
|
||||
storageType: StorageTypeRclone,
|
||||
createStorage: func(workspaceID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeRclone,
|
||||
Name: "Test Rclone Storage",
|
||||
RcloneStorage: &rclone_storage.RcloneStorage{
|
||||
ConfigContent: "[myremote]\ntype = s3\nprovider = AWS\naccess_key_id = test\nsecret_access_key = secret\n",
|
||||
RemotePath: "/backups",
|
||||
},
|
||||
}
|
||||
},
|
||||
updateStorage: func(workspaceID uuid.UUID, storageID uuid.UUID) *Storage {
|
||||
return &Storage{
|
||||
ID: storageID,
|
||||
WorkspaceID: workspaceID,
|
||||
Type: StorageTypeRclone,
|
||||
Name: "Updated Rclone Storage",
|
||||
RcloneStorage: &rclone_storage.RcloneStorage{
|
||||
ConfigContent: "",
|
||||
RemotePath: "/backups2",
|
||||
},
|
||||
}
|
||||
},
|
||||
verifySensitiveData: func(t *testing.T, storage *Storage) {
|
||||
assert.True(t, strings.HasPrefix(storage.RcloneStorage.ConfigContent, "enc:"),
|
||||
"ConfigContent should be encrypted with 'enc:' prefix")
|
||||
|
||||
encryptor := encryption.GetFieldEncryptor()
|
||||
configContent, err := encryptor.Decrypt(
|
||||
storage.ID,
|
||||
storage.RcloneStorage.ConfigContent,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(
|
||||
t,
|
||||
"[myremote]\ntype = s3\nprovider = AWS\naccess_key_id = test\nsecret_access_key = secret\n",
|
||||
configContent,
|
||||
)
|
||||
},
|
||||
verifyHiddenData: func(t *testing.T, storage *Storage) {
|
||||
assert.Equal(t, "", storage.RcloneStorage.ConfigContent)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
@@ -8,4 +8,7 @@ const (
|
||||
StorageTypeGoogleDrive StorageType = "GOOGLE_DRIVE"
|
||||
StorageTypeNAS StorageType = "NAS"
|
||||
StorageTypeAzureBlob StorageType = "AZURE_BLOB"
|
||||
StorageTypeFTP StorageType = "FTP"
|
||||
StorageTypeSFTP StorageType = "SFTP"
|
||||
StorageTypeRclone StorageType = "RCLONE"
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
|
||||
type StorageFileSaver interface {
|
||||
SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
|
||||
@@ -1,14 +1,18 @@
|
||||
package storages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
|
||||
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
|
||||
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -27,15 +31,19 @@ type Storage struct {
|
||||
GoogleDriveStorage *google_drive_storage.GoogleDriveStorage `json:"googleDriveStorage" gorm:"foreignKey:StorageID"`
|
||||
NASStorage *nas_storage.NASStorage `json:"nasStorage" gorm:"foreignKey:StorageID"`
|
||||
AzureBlobStorage *azure_blob_storage.AzureBlobStorage `json:"azureBlobStorage" gorm:"foreignKey:StorageID"`
|
||||
FTPStorage *ftp_storage.FTPStorage `json:"ftpStorage" gorm:"foreignKey:StorageID"`
|
||||
SFTPStorage *sftp_storage.SFTPStorage `json:"sftpStorage" gorm:"foreignKey:StorageID"`
|
||||
RcloneStorage *rclone_storage.RcloneStorage `json:"rcloneStorage" gorm:"foreignKey:StorageID"`
|
||||
}
|
||||
|
||||
func (s *Storage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
err := s.getSpecificStorage().SaveFile(encryptor, logger, fileID, file)
|
||||
err := s.getSpecificStorage().SaveFile(ctx, encryptor, logger, fileID, file)
|
||||
if err != nil {
|
||||
lastSaveError := err.Error()
|
||||
s.LastSaveError = &lastSaveError
|
||||
@@ -107,6 +115,18 @@ func (s *Storage) Update(incoming *Storage) {
|
||||
if s.AzureBlobStorage != nil && incoming.AzureBlobStorage != nil {
|
||||
s.AzureBlobStorage.Update(incoming.AzureBlobStorage)
|
||||
}
|
||||
case StorageTypeFTP:
|
||||
if s.FTPStorage != nil && incoming.FTPStorage != nil {
|
||||
s.FTPStorage.Update(incoming.FTPStorage)
|
||||
}
|
||||
case StorageTypeSFTP:
|
||||
if s.SFTPStorage != nil && incoming.SFTPStorage != nil {
|
||||
s.SFTPStorage.Update(incoming.SFTPStorage)
|
||||
}
|
||||
case StorageTypeRclone:
|
||||
if s.RcloneStorage != nil && incoming.RcloneStorage != nil {
|
||||
s.RcloneStorage.Update(incoming.RcloneStorage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,6 +142,12 @@ func (s *Storage) getSpecificStorage() StorageFileSaver {
|
||||
return s.NASStorage
|
||||
case StorageTypeAzureBlob:
|
||||
return s.AzureBlobStorage
|
||||
case StorageTypeFTP:
|
||||
return s.FTPStorage
|
||||
case StorageTypeSFTP:
|
||||
return s.SFTPStorage
|
||||
case StorageTypeRclone:
|
||||
return s.RcloneStorage
|
||||
default:
|
||||
panic("invalid storage type: " + string(s.Type))
|
||||
}
|
||||
|
||||
@@ -9,10 +9,13 @@ import (
|
||||
"path/filepath"
|
||||
"postgresus-backend/internal/config"
|
||||
azure_blob_storage "postgresus-backend/internal/features/storages/models/azure_blob"
|
||||
ftp_storage "postgresus-backend/internal/features/storages/models/ftp"
|
||||
google_drive_storage "postgresus-backend/internal/features/storages/models/google_drive"
|
||||
local_storage "postgresus-backend/internal/features/storages/models/local"
|
||||
nas_storage "postgresus-backend/internal/features/storages/models/nas"
|
||||
rclone_storage "postgresus-backend/internal/features/storages/models/rclone"
|
||||
s3_storage "postgresus-backend/internal/features/storages/models/s3"
|
||||
sftp_storage "postgresus-backend/internal/features/storages/models/sftp"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"postgresus-backend/internal/util/logger"
|
||||
"strconv"
|
||||
@@ -70,6 +73,22 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Setup FTP port
|
||||
ftpPort := 21
|
||||
if portStr := config.GetEnv().TestFTPPort; portStr != "" {
|
||||
if port, err := strconv.Atoi(portStr); err == nil {
|
||||
ftpPort = port
|
||||
}
|
||||
}
|
||||
|
||||
// Setup SFTP port
|
||||
sftpPort := 22
|
||||
if portStr := config.GetEnv().TestSFTPPort; portStr != "" {
|
||||
if port, err := strconv.Atoi(portStr); err == nil {
|
||||
sftpPort = port
|
||||
}
|
||||
}
|
||||
|
||||
// Run tests
|
||||
testCases := []struct {
|
||||
name string
|
||||
@@ -124,6 +143,44 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
ContainerName: azuriteContainer.containerNameStr,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "FTPStorage",
|
||||
storage: &ftp_storage.FTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Port: ftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
UseSSL: false,
|
||||
Path: "test-files",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SFTPStorage",
|
||||
storage: &sftp_storage.SFTPStorage{
|
||||
StorageID: uuid.New(),
|
||||
Host: "localhost",
|
||||
Port: sftpPort,
|
||||
Username: "testuser",
|
||||
Password: "testpassword",
|
||||
SkipHostKeyVerify: true,
|
||||
Path: "upload",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "RcloneStorage",
|
||||
storage: &rclone_storage.RcloneStorage{
|
||||
StorageID: uuid.New(),
|
||||
ConfigContent: fmt.Sprintf(`[minio]
|
||||
type = s3
|
||||
provider = Other
|
||||
access_key_id = %s
|
||||
secret_access_key = %s
|
||||
endpoint = http://%s
|
||||
acl = private`, s3Container.accessKey, s3Container.secretKey, s3Container.endpoint),
|
||||
RemotePath: s3Container.bucketName,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add Google Drive storage test only if environment variables are available
|
||||
@@ -167,6 +224,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
fileID := uuid.New()
|
||||
|
||||
err = tc.storage.SaveFile(
|
||||
context.Background(),
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
@@ -189,6 +247,7 @@ func Test_Storage_BasicOperations(t *testing.T) {
|
||||
|
||||
fileID := uuid.New()
|
||||
err = tc.storage.SaveFile(
|
||||
context.Background(),
|
||||
encryptor,
|
||||
logger.GetLogger(),
|
||||
fileID,
|
||||
@@ -238,7 +297,7 @@ func setupS3Container(ctx context.Context) (*S3Container, error) {
|
||||
secretKey := "testpassword"
|
||||
bucketName := "test-bucket"
|
||||
region := "us-east-1"
|
||||
endpoint := fmt.Sprintf("localhost:%s", env.TestMinioPort)
|
||||
endpoint := fmt.Sprintf("127.0.0.1:%s", env.TestMinioPort)
|
||||
|
||||
// Create MinIO client and ensure bucket exists
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
|
||||
@@ -3,19 +3,44 @@ package azure_blob_storage
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
|
||||
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/blockblob"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
azureConnectTimeout = 30 * time.Second
|
||||
azureResponseTimeout = 30 * time.Second
|
||||
azureIdleConnTimeout = 90 * time.Second
|
||||
azureTLSHandshakeTimeout = 30 * time.Second
|
||||
|
||||
// Chunk size for block blob uploads - 16MB provides good balance between
|
||||
// memory usage and upload efficiency. This creates backpressure to pg_dump
|
||||
// by only reading one chunk at a time and waiting for Azure to confirm receipt.
|
||||
azureChunkSize = 16 * 1024 * 1024
|
||||
)
|
||||
|
||||
type readSeekCloser struct {
|
||||
*bytes.Reader
|
||||
}
|
||||
|
||||
func (r *readSeekCloser) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type AuthMethod string
|
||||
|
||||
const (
|
||||
@@ -39,27 +64,91 @@ func (s *AzureBlobStorage) TableName() string {
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled before start: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
blobName := s.buildBlobName(fileID.String())
|
||||
blockBlobClient := client.ServiceClient().
|
||||
NewContainerClient(s.ContainerName).
|
||||
NewBlockBlobClient(blobName)
|
||||
|
||||
_, err = client.UploadStream(
|
||||
context.TODO(),
|
||||
s.ContainerName,
|
||||
blobName,
|
||||
file,
|
||||
nil,
|
||||
)
|
||||
var blockIDs []string
|
||||
blockNumber := 0
|
||||
buf := make([]byte, azureChunkSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
n, readErr := io.ReadFull(file, buf)
|
||||
|
||||
if n == 0 && readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
|
||||
return fmt.Errorf("read error: %w", readErr)
|
||||
}
|
||||
|
||||
blockID := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%06d", blockNumber)))
|
||||
|
||||
_, err := blockBlobClient.StageBlock(
|
||||
ctx,
|
||||
blockID,
|
||||
&readSeekCloser{bytes.NewReader(buf[:n])},
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled: %w", ctx.Err())
|
||||
default:
|
||||
return fmt.Errorf("failed to stage block %d: %w", blockNumber, err)
|
||||
}
|
||||
}
|
||||
|
||||
blockIDs = append(blockIDs, blockID)
|
||||
blockNumber++
|
||||
|
||||
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(blockIDs) == 0 {
|
||||
_, err = client.UploadStream(
|
||||
ctx,
|
||||
s.ContainerName,
|
||||
blobName,
|
||||
bytes.NewReader([]byte{}),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload empty blob: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = blockBlobClient.CommitBlockList(ctx, blockIDs, &blockblob.CommitBlockListOptions{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload blob to Azure: %w", err)
|
||||
return fmt.Errorf("failed to commit block list: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -253,6 +342,8 @@ func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azbl
|
||||
var client *azblob.Client
|
||||
var err error
|
||||
|
||||
clientOptions := s.buildClientOptions()
|
||||
|
||||
switch s.AuthMethod {
|
||||
case AuthMethodConnectionString:
|
||||
connectionString, decryptErr := encryptor.Decrypt(s.StorageID, s.ConnectionString)
|
||||
@@ -260,7 +351,7 @@ func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azbl
|
||||
return nil, fmt.Errorf("failed to decrypt Azure connection string: %w", decryptErr)
|
||||
}
|
||||
|
||||
client, err = azblob.NewClientFromConnectionString(connectionString, nil)
|
||||
client, err = azblob.NewClientFromConnectionString(connectionString, clientOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to create Azure Blob client from connection string: %w",
|
||||
@@ -279,7 +370,7 @@ func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azbl
|
||||
return nil, fmt.Errorf("failed to create Azure shared key credential: %w", credErr)
|
||||
}
|
||||
|
||||
client, err = azblob.NewClientWithSharedKeyCredential(accountURL, credential, nil)
|
||||
client, err = azblob.NewClientWithSharedKeyCredential(accountURL, credential, clientOptions)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Azure Blob client with shared key: %w", err)
|
||||
}
|
||||
@@ -290,6 +381,26 @@ func (s *AzureBlobStorage) getClient(encryptor encryption.FieldEncryptor) (*azbl
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) buildClientOptions() *azblob.ClientOptions {
|
||||
transport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: azureConnectTimeout,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: azureTLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: azureResponseTimeout,
|
||||
IdleConnTimeout: azureIdleConnTimeout,
|
||||
}
|
||||
|
||||
return &azblob.ClientOptions{
|
||||
ClientOptions: azcore.ClientOptions{
|
||||
Transport: &http.Client{Transport: transport},
|
||||
Retry: policy.RetryOptions{
|
||||
MaxRetries: 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AzureBlobStorage) buildAccountURL() string {
|
||||
if s.Endpoint != "" {
|
||||
endpoint := s.Endpoint
|
||||
|
||||
368
backend/internal/features/storages/models/ftp/model.go
Normal file
368
backend/internal/features/storages/models/ftp/model.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package ftp_storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jlaffaye/ftp"
|
||||
)
|
||||
|
||||
const (
|
||||
ftpConnectTimeout = 30 * time.Second
|
||||
ftpTestConnectTimeout = 10 * time.Second
|
||||
ftpChunkSize = 16 * 1024 * 1024
|
||||
)
|
||||
|
||||
type FTPStorage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
Host string `json:"host" gorm:"not null;type:text;column:host"`
|
||||
Port int `json:"port" gorm:"not null;default:21;column:port"`
|
||||
Username string `json:"username" gorm:"not null;type:text;column:username"`
|
||||
Password string `json:"password" gorm:"not null;type:text;column:password"`
|
||||
Path string `json:"path" gorm:"type:text;column:path"`
|
||||
UseSSL bool `json:"useSsl" gorm:"not null;default:false;column:use_ssl"`
|
||||
SkipTLSVerify bool `json:"skipTlsVerify" gorm:"not null;default:false;column:skip_tls_verify"`
|
||||
}
|
||||
|
||||
func (f *FTPStorage) TableName() string {
|
||||
return "ftp_storages"
|
||||
}
|
||||
|
||||
func (f *FTPStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to FTP storage", "fileId", fileID.String(), "host", f.Host)
|
||||
|
||||
conn, err := f.connect(encryptor, ftpConnectTimeout)
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to FTP", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if quitErr := conn.Quit(); quitErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close FTP connection",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"error",
|
||||
quitErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
if f.Path != "" {
|
||||
if err := f.ensureDirectory(conn, f.Path); err != nil {
|
||||
logger.Error(
|
||||
"Failed to ensure directory",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"path",
|
||||
f.Path,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return fmt.Errorf("failed to ensure directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file to FTP", "fileId", fileID.String(), "filePath", filePath)
|
||||
|
||||
ctxReader := &contextReader{ctx: ctx, reader: file}
|
||||
|
||||
err = conn.Stor(filePath, ctxReader)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("FTP upload cancelled", "fileId", fileID.String())
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error("Failed to upload file to FTP", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to upload file to FTP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to FTP storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
conn, err := f.connect(encryptor, ftpConnectTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
|
||||
resp, err := conn.Retr(filePath)
|
||||
if err != nil {
|
||||
_ = conn.Quit()
|
||||
return nil, fmt.Errorf("failed to retrieve file from FTP: %w", err)
|
||||
}
|
||||
|
||||
return &ftpFileReader{
|
||||
response: resp,
|
||||
conn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
conn, err := f.connect(encryptor, ftpConnectTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Quit()
|
||||
}()
|
||||
|
||||
filePath := f.getFilePath(fileID.String())
|
||||
|
||||
_, err = conn.FileSize(filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = conn.Delete(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from FTP: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if f.Host == "" {
|
||||
return errors.New("FTP host is required")
|
||||
}
|
||||
if f.Username == "" {
|
||||
return errors.New("FTP username is required")
|
||||
}
|
||||
if f.Password == "" {
|
||||
return errors.New("FTP password is required")
|
||||
}
|
||||
if f.Port <= 0 || f.Port > 65535 {
|
||||
return errors.New("FTP port must be between 1 and 65535")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ftpTestConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := f.connectWithContext(ctx, encryptor, ftpTestConnectTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to FTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Quit()
|
||||
}()
|
||||
|
||||
if f.Path != "" {
|
||||
if err := f.ensureDirectory(conn, f.Path); err != nil {
|
||||
return fmt.Errorf("failed to access or create path '%s': %w", f.Path, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) HideSensitiveData() {
|
||||
f.Password = ""
|
||||
}
|
||||
|
||||
func (f *FTPStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if f.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(f.StorageID, f.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt FTP password: %w", err)
|
||||
}
|
||||
f.Password = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) Update(incoming *FTPStorage) {
|
||||
f.Host = incoming.Host
|
||||
f.Port = incoming.Port
|
||||
f.Username = incoming.Username
|
||||
f.UseSSL = incoming.UseSSL
|
||||
f.SkipTLSVerify = incoming.SkipTLSVerify
|
||||
f.Path = incoming.Path
|
||||
|
||||
if incoming.Password != "" {
|
||||
f.Password = incoming.Password
|
||||
}
|
||||
}
|
||||
|
||||
func (f *FTPStorage) connect(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
timeout time.Duration,
|
||||
) (*ftp.ServerConn, error) {
|
||||
return f.connectWithContext(context.Background(), encryptor, timeout)
|
||||
}
|
||||
|
||||
func (f *FTPStorage) connectWithContext(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
timeout time.Duration,
|
||||
) (*ftp.ServerConn, error) {
|
||||
password, err := encryptor.Decrypt(f.StorageID, f.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt FTP password: %w", err)
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", f.Host, f.Port)
|
||||
|
||||
dialCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
var conn *ftp.ServerConn
|
||||
if f.UseSSL {
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: f.Host,
|
||||
InsecureSkipVerify: f.SkipTLSVerify,
|
||||
}
|
||||
conn, err = ftp.Dial(address,
|
||||
ftp.DialWithContext(dialCtx),
|
||||
ftp.DialWithExplicitTLS(tlsConfig),
|
||||
)
|
||||
} else {
|
||||
conn, err = ftp.Dial(address, ftp.DialWithContext(dialCtx))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to dial FTP server: %w", err)
|
||||
}
|
||||
|
||||
err = conn.Login(f.Username, password)
|
||||
if err != nil {
|
||||
_ = conn.Quit()
|
||||
return nil, fmt.Errorf("failed to login to FTP server: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) ensureDirectory(conn *ftp.ServerConn, path string) error {
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
currentPath := ""
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" || part == "." {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPath == "" {
|
||||
currentPath = part
|
||||
} else {
|
||||
currentPath = currentPath + "/" + part
|
||||
}
|
||||
|
||||
err := conn.ChangeDir(currentPath)
|
||||
if err != nil {
|
||||
err = conn.MakeDir(currentPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = conn.ChangeDirToParent()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to change to parent directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *FTPStorage) getFilePath(filename string) string {
|
||||
if f.Path == "" {
|
||||
return filename
|
||||
}
|
||||
|
||||
path := strings.TrimPrefix(f.Path, "/")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
||||
return path + "/" + filename
|
||||
}
|
||||
|
||||
type ftpFileReader struct {
|
||||
response *ftp.Response
|
||||
conn *ftp.ServerConn
|
||||
}
|
||||
|
||||
func (r *ftpFileReader) Read(p []byte) (n int, err error) {
|
||||
return r.response.Read(p)
|
||||
}
|
||||
|
||||
func (r *ftpFileReader) Close() error {
|
||||
var errs []error
|
||||
|
||||
if r.response != nil {
|
||||
if err := r.response.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close response: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if r.conn != nil {
|
||||
if err := r.conn.Quit(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close connection: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type contextReader struct {
|
||||
ctx context.Context
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (r *contextReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return 0, r.ctx.Err()
|
||||
default:
|
||||
return r.reader.Read(p)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,9 +18,22 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
|
||||
drive "google.golang.org/api/drive/v3"
|
||||
"google.golang.org/api/googleapi"
|
||||
"google.golang.org/api/option"
|
||||
)
|
||||
|
||||
const (
|
||||
gdConnectTimeout = 30 * time.Second
|
||||
gdResponseTimeout = 30 * time.Second
|
||||
gdIdleConnTimeout = 90 * time.Second
|
||||
gdTLSHandshakeTimeout = 30 * time.Second
|
||||
|
||||
// Chunk size for Google Drive resumable uploads - 16MB provides good balance
|
||||
// between memory usage and upload efficiency. Google Drive requires chunks
|
||||
// to be multiples of 256KB for resumable uploads.
|
||||
gdChunkSize = 16 * 1024 * 1024
|
||||
)
|
||||
|
||||
type GoogleDriveStorage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
ClientID string `json:"clientId" gorm:"not null;type:text;column:client_id"`
|
||||
@@ -31,31 +46,44 @@ func (s *GoogleDriveStorage) TableName() string {
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
|
||||
filename := fileID.String()
|
||||
|
||||
// Ensure the postgresus_backups folder exists
|
||||
folderID, err := s.ensureBackupsFolderExists(ctx, driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create/find backups folder: %w", err)
|
||||
}
|
||||
|
||||
// Delete any previous copy so we keep at most one object per logical file.
|
||||
_ = s.deleteByName(ctx, driveService, filename, folderID) // ignore "not found"
|
||||
_ = s.deleteByName(ctx, driveService, filename, folderID)
|
||||
|
||||
fileMeta := &drive.File{
|
||||
Name: filename,
|
||||
Parents: []string{folderID},
|
||||
}
|
||||
|
||||
_, err = driveService.Files.Create(fileMeta).Media(file).Context(ctx).Do()
|
||||
backpressureReader := &backpressureReader{
|
||||
reader: file,
|
||||
ctx: ctx,
|
||||
chunkSize: gdChunkSize,
|
||||
buf: make([]byte, gdChunkSize),
|
||||
}
|
||||
|
||||
_, err = driveService.Files.Create(fileMeta).
|
||||
Media(backpressureReader, googleapi.ChunkSize(gdChunkSize)).
|
||||
Context(ctx).
|
||||
Do()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
return fmt.Errorf("failed to upload file to Google Drive: %w", err)
|
||||
}
|
||||
|
||||
@@ -70,30 +98,85 @@ func (s *GoogleDriveStorage) SaveFile(
|
||||
})
|
||||
}
|
||||
|
||||
type backpressureReader struct {
|
||||
reader io.Reader
|
||||
ctx context.Context
|
||||
chunkSize int
|
||||
buf []byte
|
||||
bufStart int
|
||||
bufEnd int
|
||||
totalBytes int64
|
||||
chunkCount int
|
||||
}
|
||||
|
||||
func (r *backpressureReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return 0, r.ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if r.bufStart >= r.bufEnd {
|
||||
r.chunkCount++
|
||||
|
||||
bytesRead, readErr := io.ReadFull(r.reader, r.buf)
|
||||
if bytesRead > 0 {
|
||||
r.bufStart = 0
|
||||
r.bufEnd = bytesRead
|
||||
}
|
||||
|
||||
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
|
||||
return 0, readErr
|
||||
}
|
||||
|
||||
if bytesRead == 0 && readErr == io.EOF {
|
||||
return 0, io.EOF
|
||||
}
|
||||
}
|
||||
|
||||
n = copy(p, r.buf[r.bufStart:r.bufEnd])
|
||||
r.bufStart += n
|
||||
r.totalBytes += int64(n)
|
||||
|
||||
if r.bufStart >= r.bufEnd {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return n, r.ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
var result io.ReadCloser
|
||||
err := s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
folderID, err := s.findBackupsFolder(driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
}
|
||||
err := s.withRetryOnAuth(
|
||||
context.Background(),
|
||||
encryptor,
|
||||
func(driveService *drive.Service) error {
|
||||
folderID, err := s.findBackupsFolder(driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
}
|
||||
|
||||
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fileIDGoogle, err := s.lookupFileID(driveService, fileID.String(), folderID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := driveService.Files.Get(fileIDGoogle).Download()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file from Google Drive: %w", err)
|
||||
}
|
||||
resp, err := driveService.Files.Get(fileIDGoogle).Download()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download file from Google Drive: %w", err)
|
||||
}
|
||||
|
||||
result = resp.Body
|
||||
return nil
|
||||
})
|
||||
result = resp.Body
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
return result, err
|
||||
}
|
||||
@@ -102,8 +185,8 @@ func (s *GoogleDriveStorage) DeleteFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
ctx := context.Background()
|
||||
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
|
||||
folderID, err := s.findBackupsFolder(driveService)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find backups folder: %w", err)
|
||||
@@ -142,8 +225,8 @@ func (s *GoogleDriveStorage) Validate(encryptor encryption.FieldEncryptor) error
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
return s.withRetryOnAuth(encryptor, func(driveService *drive.Service) error {
|
||||
ctx := context.Background()
|
||||
ctx := context.Background()
|
||||
return s.withRetryOnAuth(ctx, encryptor, func(driveService *drive.Service) error {
|
||||
testFilename := "test-connection-" + uuid.New().String()
|
||||
testData := []byte("test")
|
||||
|
||||
@@ -243,9 +326,16 @@ func (s *GoogleDriveStorage) Update(incoming *GoogleDriveStorage) {
|
||||
|
||||
// withRetryOnAuth executes the provided function with retry logic for authentication errors
|
||||
func (s *GoogleDriveStorage) withRetryOnAuth(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fn func(*drive.Service) error,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
driveService, err := s.getDriveService(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -253,6 +343,12 @@ func (s *GoogleDriveStorage) withRetryOnAuth(
|
||||
|
||||
err = fn(driveService)
|
||||
if err != nil && s.isAuthError(err) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// Try to refresh token and retry once
|
||||
fmt.Printf("Google Drive auth error detected, attempting token refresh: %v\n", err)
|
||||
|
||||
@@ -422,7 +518,6 @@ func (s *GoogleDriveStorage) getDriveService(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt credentials before use
|
||||
clientSecret, err := encryptor.Decrypt(s.StorageID, s.ClientSecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt Google Drive client secret: %w", err)
|
||||
@@ -449,16 +544,16 @@ func (s *GoogleDriveStorage) getDriveService(
|
||||
|
||||
tokenSource := cfg.TokenSource(ctx, &token)
|
||||
|
||||
// Force token validation to ensure we're using the current token
|
||||
currentToken, err := tokenSource.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current token: %w", err)
|
||||
}
|
||||
|
||||
// Create a new token source with the validated token
|
||||
validatedTokenSource := oauth2.StaticTokenSource(currentToken)
|
||||
|
||||
driveService, err := drive.NewService(ctx, option.WithTokenSource(validatedTokenSource))
|
||||
httpClient := s.buildHTTPClient(validatedTokenSource)
|
||||
|
||||
driveService, err := drive.NewService(ctx, option.WithHTTPClient(httpClient))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to create Drive client: %w", err)
|
||||
}
|
||||
@@ -466,6 +561,24 @@ func (s *GoogleDriveStorage) getDriveService(
|
||||
return driveService, nil
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) buildHTTPClient(tokenSource oauth2.TokenSource) *http.Client {
|
||||
transport := &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: gdConnectTimeout,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: gdTLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: gdResponseTimeout,
|
||||
IdleConnTimeout: gdIdleConnTimeout,
|
||||
}
|
||||
|
||||
return &http.Client{
|
||||
Transport: &oauth2.Transport{
|
||||
Source: tokenSource,
|
||||
Base: transport,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GoogleDriveStorage) lookupFileID(
|
||||
driveService *drive.Service,
|
||||
name string,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package local_storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -13,6 +14,13 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
// Chunk size for local storage writes - 8MB per buffer with double-buffering
|
||||
// allows overlapped I/O while keeping total memory under 32MB.
|
||||
// Two 8MB buffers = 16MB for local storage, plus 8MB for pg_dump buffer = ~25MB total.
|
||||
localChunkSize = 8 * 1024 * 1024
|
||||
)
|
||||
|
||||
// LocalStorage uses ./postgresus_local_backups folder as a
|
||||
// directory for backups and ./postgresus_local_temp folder as a
|
||||
// directory for temp files
|
||||
@@ -25,11 +33,18 @@ func (l *LocalStorage) TableName() string {
|
||||
}
|
||||
|
||||
func (l *LocalStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to local storage", "fileId", fileID.String())
|
||||
|
||||
err := files_utils.EnsureDirectories([]string{
|
||||
@@ -60,7 +75,7 @@ func (l *LocalStorage) SaveFile(
|
||||
}()
|
||||
|
||||
logger.Debug("Copying file data to temp file", "fileId", fileID.String())
|
||||
_, err = io.Copy(tempFile, file)
|
||||
_, err = copyWithContext(ctx, tempFile, file)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write to temp file", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to write to temp file: %w", err)
|
||||
@@ -175,3 +190,35 @@ func (l *LocalStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor)
|
||||
|
||||
func (l *LocalStorage) Update(incoming *LocalStorage) {
|
||||
}
|
||||
|
||||
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
buf := make([]byte, localChunkSize)
|
||||
var written int64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return written, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
nw, writeErr := dst.Write(buf[:nr])
|
||||
written += int64(nw)
|
||||
if writeErr != nil {
|
||||
return written, writeErr
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
|
||||
if readErr == io.EOF {
|
||||
return written, nil
|
||||
}
|
||||
if readErr != nil {
|
||||
return written, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package nas_storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -16,6 +17,13 @@ import (
|
||||
"github.com/hirochachacha/go-smb2"
|
||||
)
|
||||
|
||||
const (
|
||||
// Chunk size for NAS uploads - 16MB provides good balance between
|
||||
// memory usage and upload efficiency. This creates backpressure to pg_dump
|
||||
// by only reading one chunk at a time and waiting for NAS to confirm receipt.
|
||||
nasChunkSize = 16 * 1024 * 1024
|
||||
)
|
||||
|
||||
type NASStorage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
Host string `json:"host" gorm:"not null;type:text;column:host"`
|
||||
@@ -33,14 +41,21 @@ func (n *NASStorage) TableName() string {
|
||||
}
|
||||
|
||||
func (n *NASStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to NAS storage", "fileId", fileID.String(), "host", n.Host)
|
||||
|
||||
session, err := n.createSession(encryptor)
|
||||
session, err := n.createSessionWithContext(ctx, encryptor)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create NAS session", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to create NAS session: %w", err)
|
||||
@@ -121,7 +136,7 @@ func (n *NASStorage) SaveFile(
|
||||
}()
|
||||
|
||||
logger.Debug("Copying file data to NAS", "fileId", fileID.String())
|
||||
_, err = io.Copy(nasFile, file)
|
||||
_, err = copyWithContext(ctx, nasFile, file)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write file to NAS", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to write file to NAS: %w", err)
|
||||
@@ -290,20 +305,24 @@ func (n *NASStorage) Update(incoming *NASStorage) {
|
||||
}
|
||||
|
||||
func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.Session, error) {
|
||||
// Create connection with timeout
|
||||
conn, err := n.createConnection()
|
||||
return n.createSessionWithContext(context.Background(), encryptor)
|
||||
}
|
||||
|
||||
func (n *NASStorage) createSessionWithContext(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) (*smb2.Session, error) {
|
||||
conn, err := n.createConnectionWithContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt password before use
|
||||
password, err := encryptor.Decrypt(n.StorageID, n.Password)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("failed to decrypt NAS password: %w", err)
|
||||
}
|
||||
|
||||
// Create SMB2 dialer
|
||||
d := &smb2.Dialer{
|
||||
Initiator: &smb2.NTLMInitiator{
|
||||
User: n.Username,
|
||||
@@ -312,7 +331,6 @@ func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.S
|
||||
},
|
||||
}
|
||||
|
||||
// Create session
|
||||
session, err := d.Dial(conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
@@ -322,34 +340,30 @@ func (n *NASStorage) createSession(encryptor encryption.FieldEncryptor) (*smb2.S
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) createConnection() (net.Conn, error) {
|
||||
func (n *NASStorage) createConnectionWithContext(ctx context.Context) (net.Conn, error) {
|
||||
address := net.JoinHostPort(n.Host, fmt.Sprintf("%d", n.Port))
|
||||
|
||||
// Create connection with timeout
|
||||
dialer := &net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
if n.UseSSL {
|
||||
// Use TLS connection
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: n.Host,
|
||||
InsecureSkipVerify: false, // Change to true if you want to skip cert verification
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
|
||||
conn, err := tls.DialWithDialer(dialer, "tcp", address, tlsConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create SSL connection to %s: %w", address, err)
|
||||
}
|
||||
return conn, nil
|
||||
} else {
|
||||
// Use regular TCP connection
|
||||
conn, err := dialer.Dial("tcp", address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection to %s: %w", address, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create connection to %s: %w", address, err)
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (n *NASStorage) ensureDirectory(fs *smb2.Share, path string) error {
|
||||
@@ -444,3 +458,71 @@ func (r *nasFileReader) Close() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type writeResult struct {
|
||||
bytesWritten int
|
||||
writeErr error
|
||||
}
|
||||
|
||||
func copyWithContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||||
buf := make([]byte, nasChunkSize)
|
||||
var written int64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return written, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
nr, readErr := io.ReadFull(src, buf)
|
||||
|
||||
if nr == 0 && readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
|
||||
return written, readErr
|
||||
}
|
||||
|
||||
writeResultCh := make(chan writeResult, 1)
|
||||
go func() {
|
||||
nw, writeErr := dst.Write(buf[0:nr])
|
||||
writeResultCh <- writeResult{nw, writeErr}
|
||||
}()
|
||||
|
||||
var nw int
|
||||
var writeErr error
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return written, ctx.Err()
|
||||
case result := <-writeResultCh:
|
||||
nw = result.bytesWritten
|
||||
writeErr = result.writeErr
|
||||
}
|
||||
|
||||
if nw < 0 || nr < nw {
|
||||
nw = 0
|
||||
if writeErr == nil {
|
||||
writeErr = errors.New("invalid write result")
|
||||
}
|
||||
}
|
||||
|
||||
if writeErr != nil {
|
||||
return written, writeErr
|
||||
}
|
||||
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
|
||||
written += int64(nw)
|
||||
|
||||
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return written, nil
|
||||
}
|
||||
|
||||
293
backend/internal/features/storages/models/rclone/model.go
Normal file
293
backend/internal/features/storages/models/rclone/model.go
Normal file
@@ -0,0 +1,293 @@
|
||||
package rclone_storage
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rclone/rclone/fs"
|
||||
"github.com/rclone/rclone/fs/config"
|
||||
"github.com/rclone/rclone/fs/operations"
|
||||
|
||||
_ "github.com/rclone/rclone/backend/all"
|
||||
)
|
||||
|
||||
const (
|
||||
rcloneOperationTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
var rcloneConfigMu sync.Mutex
|
||||
|
||||
type RcloneStorage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
ConfigContent string `json:"configContent" gorm:"not null;type:text;column:config_content"`
|
||||
RemotePath string `json:"remotePath" gorm:"type:text;column:remote_path"`
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) TableName() string {
|
||||
return "rclone_storages"
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to rclone storage", "fileId", fileID.String())
|
||||
|
||||
remoteFs, err := r.getFs(ctx, encryptor)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create rclone filesystem", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file via rclone", "fileId", fileID.String(), "filePath", filePath)
|
||||
|
||||
_, err = operations.Rcat(ctx, remoteFs, filePath, io.NopCloser(file), time.Now().UTC(), nil)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("Rclone upload cancelled", "fileId", fileID.String())
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error(
|
||||
"Failed to upload file via rclone",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return fmt.Errorf("failed to upload file via rclone: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to rclone storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
ctx := context.Background()
|
||||
|
||||
remoteFs, err := r.getFs(ctx, encryptor)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
|
||||
obj, err := remoteFs.NewObject(ctx, filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get object from rclone: %w", err)
|
||||
}
|
||||
|
||||
reader, err := obj.Open(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open object from rclone: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
ctx := context.Background()
|
||||
|
||||
remoteFs, err := r.getFs(ctx, encryptor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
filePath := r.getFilePath(fileID.String())
|
||||
|
||||
obj, err := remoteFs.NewObject(ctx, filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = obj.Remove(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from rclone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if r.ConfigContent == "" {
|
||||
return errors.New("rclone config content is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), rcloneOperationTimeout)
|
||||
defer cancel()
|
||||
|
||||
remoteFs, err := r.getFs(ctx, encryptor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create rclone filesystem: %w", err)
|
||||
}
|
||||
|
||||
testFileID := uuid.New().String() + "-test"
|
||||
testFilePath := r.getFilePath(testFileID)
|
||||
testData := strings.NewReader("test connection")
|
||||
|
||||
_, err = operations.Rcat(
|
||||
ctx,
|
||||
remoteFs,
|
||||
testFilePath,
|
||||
io.NopCloser(testData),
|
||||
time.Now().UTC(),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload test file via rclone: %w", err)
|
||||
}
|
||||
|
||||
obj, err := remoteFs.NewObject(ctx, testFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get test file from rclone: %w", err)
|
||||
}
|
||||
|
||||
err = obj.Remove(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete test file from rclone: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) HideSensitiveData() {
|
||||
r.ConfigContent = ""
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if r.ConfigContent != "" {
|
||||
encrypted, err := encryptor.Encrypt(r.StorageID, r.ConfigContent)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt rclone config content: %w", err)
|
||||
}
|
||||
r.ConfigContent = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) Update(incoming *RcloneStorage) {
|
||||
r.RemotePath = incoming.RemotePath
|
||||
|
||||
if incoming.ConfigContent != "" {
|
||||
r.ConfigContent = incoming.ConfigContent
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) getFs(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) (fs.Fs, error) {
|
||||
configContent, err := encryptor.Decrypt(r.StorageID, r.ConfigContent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt rclone config content: %w", err)
|
||||
}
|
||||
|
||||
rcloneConfigMu.Lock()
|
||||
defer rcloneConfigMu.Unlock()
|
||||
|
||||
parsedConfig, err := parseConfigContent(configContent)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse rclone config: %w", err)
|
||||
}
|
||||
|
||||
if len(parsedConfig) == 0 {
|
||||
return nil, errors.New("rclone config must contain at least one remote section")
|
||||
}
|
||||
|
||||
var remoteName string
|
||||
for section, values := range parsedConfig {
|
||||
remoteName = section
|
||||
for key, value := range values {
|
||||
config.FileSetValue(section, key, value)
|
||||
}
|
||||
}
|
||||
|
||||
remotePath := remoteName + ":"
|
||||
if r.RemotePath != "" {
|
||||
remotePath = remoteName + ":" + strings.TrimPrefix(r.RemotePath, "/")
|
||||
}
|
||||
|
||||
remoteFs, err := fs.NewFs(ctx, remotePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf(
|
||||
"failed to create rclone filesystem for remote '%s': %w",
|
||||
remoteName,
|
||||
err,
|
||||
)
|
||||
}
|
||||
|
||||
return remoteFs, nil
|
||||
}
|
||||
|
||||
func (r *RcloneStorage) getFilePath(filename string) string {
|
||||
return filename
|
||||
}
|
||||
|
||||
func parseConfigContent(content string) (map[string]map[string]string, error) {
|
||||
sections := make(map[string]map[string]string)
|
||||
|
||||
var currentSection string
|
||||
scanner := bufio.NewScanner(strings.NewReader(content))
|
||||
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
|
||||
if line == "" || strings.HasPrefix(line, "#") || strings.HasPrefix(line, ";") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
|
||||
currentSection = strings.TrimPrefix(strings.TrimSuffix(line, "]"), "[")
|
||||
if sections[currentSection] == nil {
|
||||
sections[currentSection] = make(map[string]string)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if currentSection != "" && strings.Contains(line, "=") {
|
||||
parts := strings.SplitN(line, "=", 2)
|
||||
key := strings.TrimSpace(parts[0])
|
||||
value := ""
|
||||
if len(parts) > 1 {
|
||||
value = strings.TrimSpace(parts[1])
|
||||
}
|
||||
sections[currentSection][key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return sections, scanner.Err()
|
||||
}
|
||||
@@ -3,10 +3,13 @@ package s3_storage
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -16,6 +19,18 @@ import (
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
const (
|
||||
s3ConnectTimeout = 30 * time.Second
|
||||
s3ResponseTimeout = 30 * time.Second
|
||||
s3IdleConnTimeout = 90 * time.Second
|
||||
s3TLSHandshakeTimeout = 30 * time.Second
|
||||
|
||||
// Chunk size for multipart uploads - 16MB provides good balance between
|
||||
// memory usage and upload efficiency. This creates backpressure to pg_dump
|
||||
// by only reading one chunk at a time and waiting for S3 to confirm receipt.
|
||||
multipartChunkSize = 16 * 1024 * 1024
|
||||
)
|
||||
|
||||
type S3Storage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
S3Bucket string `json:"s3Bucket" gorm:"not null;type:text;column:s3_bucket"`
|
||||
@@ -26,6 +41,7 @@ type S3Storage struct {
|
||||
|
||||
S3Prefix string `json:"s3Prefix" gorm:"type:text;column:s3_prefix"`
|
||||
S3UseVirtualHostedStyle bool `json:"s3UseVirtualHostedStyle" gorm:"default:false;column:s3_use_virtual_hosted_style"`
|
||||
SkipTLSVerify bool `json:"skipTLSVerify" gorm:"default:false;column:skip_tls_verify"`
|
||||
}
|
||||
|
||||
func (s *S3Storage) TableName() string {
|
||||
@@ -33,29 +49,123 @@ func (s *S3Storage) TableName() string {
|
||||
}
|
||||
|
||||
func (s *S3Storage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
client, err := s.getClient(encryptor)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled before start: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
coreClient, err := s.getCoreClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
objectKey := s.buildObjectKey(fileID.String())
|
||||
|
||||
// Upload the file using MinIO client with streaming (size = -1 for unknown size)
|
||||
_, err = client.PutObject(
|
||||
context.TODO(),
|
||||
uploadID, err := coreClient.NewMultipartUpload(
|
||||
ctx,
|
||||
s.S3Bucket,
|
||||
objectKey,
|
||||
file,
|
||||
-1,
|
||||
minio.PutObjectOptions{},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload file to S3: %w", err)
|
||||
return fmt.Errorf("failed to initiate multipart upload: %w", err)
|
||||
}
|
||||
|
||||
var parts []minio.CompletePart
|
||||
partNumber := 1
|
||||
buf := make([]byte, multipartChunkSize)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
|
||||
return fmt.Errorf("upload cancelled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
|
||||
n, readErr := io.ReadFull(file, buf)
|
||||
|
||||
if n == 0 && readErr == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if readErr != nil && readErr != io.EOF && readErr != io.ErrUnexpectedEOF {
|
||||
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
|
||||
return fmt.Errorf("read error: %w", readErr)
|
||||
}
|
||||
|
||||
part, err := coreClient.PutObjectPart(
|
||||
ctx,
|
||||
s.S3Bucket,
|
||||
objectKey,
|
||||
uploadID,
|
||||
partNumber,
|
||||
bytes.NewReader(buf[:n]),
|
||||
int64(n),
|
||||
minio.PutObjectPartOptions{},
|
||||
)
|
||||
if err != nil {
|
||||
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("upload cancelled: %w", ctx.Err())
|
||||
default:
|
||||
return fmt.Errorf("failed to upload part %d: %w", partNumber, err)
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, minio.CompletePart{
|
||||
PartNumber: partNumber,
|
||||
ETag: part.ETag,
|
||||
})
|
||||
|
||||
partNumber++
|
||||
|
||||
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
|
||||
|
||||
client, err := s.getClient(encryptor)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = client.PutObject(
|
||||
ctx,
|
||||
s.S3Bucket,
|
||||
objectKey,
|
||||
bytes.NewReader([]byte{}),
|
||||
0,
|
||||
minio.PutObjectOptions{},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload empty file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = coreClient.CompleteMultipartUpload(
|
||||
ctx,
|
||||
s.S3Bucket,
|
||||
objectKey,
|
||||
uploadID,
|
||||
parts,
|
||||
minio.PutObjectOptions{},
|
||||
)
|
||||
if err != nil {
|
||||
_ = coreClient.AbortMultipartUpload(ctx, s.S3Bucket, objectKey, uploadID)
|
||||
return fmt.Errorf("failed to complete multipart upload: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -223,6 +333,7 @@ func (s *S3Storage) Update(incoming *S3Storage) {
|
||||
s.S3Region = incoming.S3Region
|
||||
s.S3Endpoint = incoming.S3Endpoint
|
||||
s.S3UseVirtualHostedStyle = incoming.S3UseVirtualHostedStyle
|
||||
s.SkipTLSVerify = incoming.SkipTLSVerify
|
||||
|
||||
if incoming.S3AccessKey != "" {
|
||||
s.S3AccessKey = incoming.S3AccessKey
|
||||
@@ -252,8 +363,54 @@ func (s *S3Storage) buildObjectKey(fileName string) string {
|
||||
}
|
||||
|
||||
func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Client, error) {
|
||||
endpoint := s.S3Endpoint
|
||||
useSSL := true
|
||||
endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, err := s.getClientParams(
|
||||
encryptor,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
|
||||
Secure: useSSL,
|
||||
Region: s.S3Region,
|
||||
BucketLookup: bucketLookup,
|
||||
Transport: transport,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize MinIO client: %w", err)
|
||||
}
|
||||
|
||||
return minioClient, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) getCoreClient(encryptor encryption.FieldEncryptor) (*minio.Core, error) {
|
||||
endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, err := s.getClientParams(
|
||||
encryptor,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
coreClient, err := minio.NewCore(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
|
||||
Secure: useSSL,
|
||||
Region: s.S3Region,
|
||||
BucketLookup: bucketLookup,
|
||||
Transport: transport,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize MinIO Core client: %w", err)
|
||||
}
|
||||
|
||||
return coreClient, nil
|
||||
}
|
||||
|
||||
func (s *S3Storage) getClientParams(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
) (endpoint string, useSSL bool, accessKey string, secretKey string, bucketLookup minio.BucketLookupType, transport *http.Transport, err error) {
|
||||
endpoint = s.S3Endpoint
|
||||
useSSL = true
|
||||
|
||||
if strings.HasPrefix(endpoint, "http://") {
|
||||
useSSL = false
|
||||
@@ -262,38 +419,36 @@ func (s *S3Storage) getClient(encryptor encryption.FieldEncryptor) (*minio.Clien
|
||||
endpoint = strings.TrimPrefix(endpoint, "https://")
|
||||
}
|
||||
|
||||
// If no endpoint is provided, use the AWS S3 endpoint for the region
|
||||
if endpoint == "" {
|
||||
endpoint = fmt.Sprintf("s3.%s.amazonaws.com", s.S3Region)
|
||||
}
|
||||
|
||||
// Decrypt credentials before use
|
||||
accessKey, err := encryptor.Decrypt(s.StorageID, s.S3AccessKey)
|
||||
accessKey, err = encryptor.Decrypt(s.StorageID, s.S3AccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
|
||||
return "", false, "", "", 0, nil, fmt.Errorf("failed to decrypt S3 access key: %w", err)
|
||||
}
|
||||
|
||||
secretKey, err := encryptor.Decrypt(s.StorageID, s.S3SecretKey)
|
||||
secretKey, err = encryptor.Decrypt(s.StorageID, s.S3SecretKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
|
||||
return "", false, "", "", 0, nil, fmt.Errorf("failed to decrypt S3 secret key: %w", err)
|
||||
}
|
||||
|
||||
// Configure bucket lookup strategy
|
||||
bucketLookup := minio.BucketLookupAuto
|
||||
bucketLookup = minio.BucketLookupAuto
|
||||
if s.S3UseVirtualHostedStyle {
|
||||
bucketLookup = minio.BucketLookupDNS
|
||||
}
|
||||
|
||||
// Initialize the MinIO client
|
||||
minioClient, err := minio.New(endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(accessKey, secretKey, ""),
|
||||
Secure: useSSL,
|
||||
Region: s.S3Region,
|
||||
BucketLookup: bucketLookup,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize MinIO client: %w", err)
|
||||
transport = &http.Transport{
|
||||
DialContext: (&net.Dialer{
|
||||
Timeout: s3ConnectTimeout,
|
||||
}).DialContext,
|
||||
TLSHandshakeTimeout: s3TLSHandshakeTimeout,
|
||||
ResponseHeaderTimeout: s3ResponseTimeout,
|
||||
IdleConnTimeout: s3IdleConnTimeout,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: s.SkipTLSVerify,
|
||||
},
|
||||
}
|
||||
|
||||
return minioClient, nil
|
||||
return endpoint, useSSL, accessKey, secretKey, bucketLookup, transport, nil
|
||||
}
|
||||
|
||||
430
backend/internal/features/storages/models/sftp/model.go
Normal file
430
backend/internal/features/storages/models/sftp/model.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package sftp_storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"postgresus-backend/internal/util/encryption"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/sftp"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
sftpConnectTimeout = 30 * time.Second
|
||||
sftpTestConnectTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
type SFTPStorage struct {
|
||||
StorageID uuid.UUID `json:"storageId" gorm:"primaryKey;type:uuid;column:storage_id"`
|
||||
Host string `json:"host" gorm:"not null;type:text;column:host"`
|
||||
Port int `json:"port" gorm:"not null;default:22;column:port"`
|
||||
Username string `json:"username" gorm:"not null;type:text;column:username"`
|
||||
Password string `json:"password" gorm:"type:text;column:password"`
|
||||
PrivateKey string `json:"privateKey" gorm:"type:text;column:private_key"`
|
||||
Path string `json:"path" gorm:"type:text;column:path"`
|
||||
SkipHostKeyVerify bool `json:"skipHostKeyVerify" gorm:"not null;default:false;column:skip_host_key_verify"`
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) TableName() string {
|
||||
return "sftp_storages"
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) SaveFile(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
logger *slog.Logger,
|
||||
fileID uuid.UUID,
|
||||
file io.Reader,
|
||||
) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.Info("Starting to save file to SFTP storage", "fileId", fileID.String(), "host", s.Host)
|
||||
|
||||
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to SFTP", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := client.Close(); closeErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close SFTP client",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
if closeErr := sshConn.Close(); closeErr != nil {
|
||||
logger.Error(
|
||||
"Failed to close SSH connection",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"error",
|
||||
closeErr,
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
if s.Path != "" {
|
||||
if err := s.ensureDirectory(client, s.Path); err != nil {
|
||||
logger.Error(
|
||||
"Failed to ensure directory",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"path",
|
||||
s.Path,
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return fmt.Errorf("failed to ensure directory: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
logger.Debug("Uploading file to SFTP", "fileId", fileID.String(), "filePath", filePath)
|
||||
|
||||
remoteFile, err := client.Create(filePath)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create remote file", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to create remote file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = remoteFile.Close()
|
||||
}()
|
||||
|
||||
ctxReader := &contextReader{ctx: ctx, reader: file}
|
||||
|
||||
_, err = io.Copy(remoteFile, ctxReader)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Info("SFTP upload cancelled", "fileId", fileID.String())
|
||||
return ctx.Err()
|
||||
default:
|
||||
logger.Error("Failed to upload file to SFTP", "fileId", fileID.String(), "error", err)
|
||||
return fmt.Errorf("failed to upload file to SFTP: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Successfully saved file to SFTP storage",
|
||||
"fileId",
|
||||
fileID.String(),
|
||||
"filePath",
|
||||
filePath,
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) GetFile(
|
||||
encryptor encryption.FieldEncryptor,
|
||||
fileID uuid.UUID,
|
||||
) (io.ReadCloser, error) {
|
||||
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
|
||||
remoteFile, err := client.Open(filePath)
|
||||
if err != nil {
|
||||
_ = client.Close()
|
||||
_ = sshConn.Close()
|
||||
return nil, fmt.Errorf("failed to open file from SFTP: %w", err)
|
||||
}
|
||||
|
||||
return &sftpFileReader{
|
||||
file: remoteFile,
|
||||
client: client,
|
||||
sshConn: sshConn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) DeleteFile(encryptor encryption.FieldEncryptor, fileID uuid.UUID) error {
|
||||
client, sshConn, err := s.connect(encryptor, sftpConnectTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
_ = sshConn.Close()
|
||||
}()
|
||||
|
||||
filePath := s.getFilePath(fileID.String())
|
||||
|
||||
_, err = client.Stat(filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = client.Remove(filePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete file from SFTP: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) Validate(encryptor encryption.FieldEncryptor) error {
|
||||
if s.Host == "" {
|
||||
return errors.New("SFTP host is required")
|
||||
}
|
||||
if s.Username == "" {
|
||||
return errors.New("SFTP username is required")
|
||||
}
|
||||
if s.Password == "" && s.PrivateKey == "" {
|
||||
return errors.New("SFTP password or private key is required")
|
||||
}
|
||||
if s.Port <= 0 || s.Port > 65535 {
|
||||
return errors.New("SFTP port must be between 1 and 65535")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) TestConnection(encryptor encryption.FieldEncryptor) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sftpTestConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
client, sshConn, err := s.connectWithContext(ctx, encryptor, sftpTestConnectTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to connect to SFTP: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close()
|
||||
_ = sshConn.Close()
|
||||
}()
|
||||
|
||||
if s.Path != "" {
|
||||
if err := s.ensureDirectory(client, s.Path); err != nil {
|
||||
return fmt.Errorf("failed to access or create path '%s': %w", s.Path, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// HideSensitiveData blanks the credential fields so the struct can be
// returned to API clients without leaking secrets.
func (s *SFTPStorage) HideSensitiveData() {
	s.Password = ""
	s.PrivateKey = ""
}
|
||||
|
||||
func (s *SFTPStorage) EncryptSensitiveData(encryptor encryption.FieldEncryptor) error {
|
||||
if s.Password != "" {
|
||||
encrypted, err := encryptor.Encrypt(s.StorageID, s.Password)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt SFTP password: %w", err)
|
||||
}
|
||||
s.Password = encrypted
|
||||
}
|
||||
|
||||
if s.PrivateKey != "" {
|
||||
encrypted, err := encryptor.Encrypt(s.StorageID, s.PrivateKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt SFTP private key: %w", err)
|
||||
}
|
||||
s.PrivateKey = encrypted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update copies the incoming configuration onto s. Credential fields are
// only overwritten when the incoming value is non-empty, so a client that
// received a redacted struct (see HideSensitiveData) can submit it back
// without wiping the stored secrets.
func (s *SFTPStorage) Update(incoming *SFTPStorage) {
	s.Host = incoming.Host
	s.Port = incoming.Port
	s.Username = incoming.Username
	s.SkipHostKeyVerify = incoming.SkipHostKeyVerify
	s.Path = incoming.Path

	if incoming.Password != "" {
		s.Password = incoming.Password
	}

	if incoming.PrivateKey != "" {
		s.PrivateKey = incoming.PrivateKey
	}
}
|
||||
|
||||
// connect opens an SFTP session without an explicit context; it simply
// delegates to connectWithContext with context.Background().
func (s *SFTPStorage) connect(
	encryptor encryption.FieldEncryptor,
	timeout time.Duration,
) (*sftp.Client, *ssh.Client, error) {
	return s.connectWithContext(context.Background(), encryptor, timeout)
}
|
||||
|
||||
func (s *SFTPStorage) connectWithContext(
|
||||
ctx context.Context,
|
||||
encryptor encryption.FieldEncryptor,
|
||||
timeout time.Duration,
|
||||
) (*sftp.Client, *ssh.Client, error) {
|
||||
var authMethods []ssh.AuthMethod
|
||||
|
||||
if s.Password != "" {
|
||||
password, err := encryptor.Decrypt(s.StorageID, s.Password)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decrypt SFTP password: %w", err)
|
||||
}
|
||||
authMethods = append(authMethods, ssh.Password(password))
|
||||
}
|
||||
|
||||
if s.PrivateKey != "" {
|
||||
privateKey, err := encryptor.Decrypt(s.StorageID, s.PrivateKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to decrypt SFTP private key: %w", err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey([]byte(privateKey))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||
}
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
}
|
||||
|
||||
var hostKeyCallback ssh.HostKeyCallback
|
||||
if s.SkipHostKeyVerify {
|
||||
hostKeyCallback = ssh.InsecureIgnoreHostKey()
|
||||
} else {
|
||||
hostKeyCallback = ssh.InsecureIgnoreHostKey()
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: s.Username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
address := fmt.Sprintf("%s:%d", s.Host, s.Port)
|
||||
|
||||
dialer := net.Dialer{Timeout: timeout}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", address)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to dial SFTP server: %w", err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := ssh.NewClientConn(conn, address, config)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, nil, fmt.Errorf("failed to create SSH connection: %w", err)
|
||||
}
|
||||
|
||||
sshClient := ssh.NewClient(sshConn, chans, reqs)
|
||||
|
||||
sftpClient, err := sftp.NewClient(sshClient)
|
||||
if err != nil {
|
||||
_ = sshClient.Close()
|
||||
return nil, nil, fmt.Errorf("failed to create SFTP client: %w", err)
|
||||
}
|
||||
|
||||
return sftpClient, sshClient, nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) ensureDirectory(client *sftp.Client, path string) error {
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(path, "/")
|
||||
currentPath := ""
|
||||
|
||||
for _, part := range parts {
|
||||
if part == "" || part == "." {
|
||||
continue
|
||||
}
|
||||
|
||||
if currentPath == "" {
|
||||
currentPath = "/" + part
|
||||
} else {
|
||||
currentPath = currentPath + "/" + part
|
||||
}
|
||||
|
||||
_, err := client.Stat(currentPath)
|
||||
if err != nil {
|
||||
err = client.Mkdir(currentPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create directory '%s': %w", currentPath, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SFTPStorage) getFilePath(filename string) string {
|
||||
if s.Path == "" {
|
||||
return filename
|
||||
}
|
||||
|
||||
path := strings.TrimPrefix(s.Path, "/")
|
||||
path = strings.TrimSuffix(path, "/")
|
||||
|
||||
return "/" + path + "/" + filename
|
||||
}
|
||||
|
||||
type sftpFileReader struct {
|
||||
file *sftp.File
|
||||
client *sftp.Client
|
||||
sshConn *ssh.Client
|
||||
}
|
||||
|
||||
func (r *sftpFileReader) Read(p []byte) (n int, err error) {
|
||||
return r.file.Read(p)
|
||||
}
|
||||
|
||||
func (r *sftpFileReader) Close() error {
|
||||
var errs []error
|
||||
|
||||
if r.file != nil {
|
||||
if err := r.file.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close file: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if r.client != nil {
|
||||
if err := r.client.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close SFTP client: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if r.sshConn != nil {
|
||||
if err := r.sshConn.Close(); err != nil {
|
||||
errs = append(errs, fmt.Errorf("failed to close SSH connection: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
if len(errs) > 0 {
|
||||
return errs[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type contextReader struct {
|
||||
ctx context.Context
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func (r *contextReader) Read(p []byte) (n int, err error) {
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return 0, r.ctx.Err()
|
||||
default:
|
||||
return r.reader.Read(p)
|
||||
}
|
||||
}
|
||||
@@ -34,17 +34,29 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
|
||||
if storage.AzureBlobStorage != nil {
|
||||
storage.AzureBlobStorage.StorageID = storage.ID
|
||||
}
|
||||
case StorageTypeFTP:
|
||||
if storage.FTPStorage != nil {
|
||||
storage.FTPStorage.StorageID = storage.ID
|
||||
}
|
||||
case StorageTypeSFTP:
|
||||
if storage.SFTPStorage != nil {
|
||||
storage.SFTPStorage.StorageID = storage.ID
|
||||
}
|
||||
case StorageTypeRclone:
|
||||
if storage.RcloneStorage != nil {
|
||||
storage.RcloneStorage.StorageID = storage.ID
|
||||
}
|
||||
}
|
||||
|
||||
if storage.ID == uuid.Nil {
|
||||
if err := tx.Create(storage).
|
||||
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage").
|
||||
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage", "FTPStorage", "SFTPStorage", "RcloneStorage").
|
||||
Error; err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := tx.Save(storage).
|
||||
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage").
|
||||
Omit("LocalStorage", "S3Storage", "GoogleDriveStorage", "NASStorage", "AzureBlobStorage", "FTPStorage", "SFTPStorage", "RcloneStorage").
|
||||
Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -86,6 +98,27 @@ func (r *StorageRepository) Save(storage *Storage) (*Storage, error) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeFTP:
|
||||
if storage.FTPStorage != nil {
|
||||
storage.FTPStorage.StorageID = storage.ID // Ensure ID is set
|
||||
if err := tx.Save(storage.FTPStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeSFTP:
|
||||
if storage.SFTPStorage != nil {
|
||||
storage.SFTPStorage.StorageID = storage.ID // Ensure ID is set
|
||||
if err := tx.Save(storage.SFTPStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeRclone:
|
||||
if storage.RcloneStorage != nil {
|
||||
storage.RcloneStorage.StorageID = storage.ID // Ensure ID is set
|
||||
if err := tx.Save(storage.RcloneStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -108,6 +141,9 @@ func (r *StorageRepository) FindByID(id uuid.UUID) (*Storage, error) {
|
||||
Preload("GoogleDriveStorage").
|
||||
Preload("NASStorage").
|
||||
Preload("AzureBlobStorage").
|
||||
Preload("FTPStorage").
|
||||
Preload("SFTPStorage").
|
||||
Preload("RcloneStorage").
|
||||
Where("id = ?", id).
|
||||
First(&s).Error; err != nil {
|
||||
return nil, err
|
||||
@@ -126,6 +162,9 @@ func (r *StorageRepository) FindByWorkspaceID(workspaceID uuid.UUID) ([]*Storage
|
||||
Preload("GoogleDriveStorage").
|
||||
Preload("NASStorage").
|
||||
Preload("AzureBlobStorage").
|
||||
Preload("FTPStorage").
|
||||
Preload("SFTPStorage").
|
||||
Preload("RcloneStorage").
|
||||
Where("workspace_id = ?", workspaceID).
|
||||
Order("name ASC").
|
||||
Find(&storages).Error; err != nil {
|
||||
@@ -169,6 +208,24 @@ func (r *StorageRepository) Delete(s *Storage) error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeFTP:
|
||||
if s.FTPStorage != nil {
|
||||
if err := tx.Delete(s.FTPStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeSFTP:
|
||||
if s.SFTPStorage != nil {
|
||||
if err := tx.Delete(s.SFTPStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
case StorageTypeRclone:
|
||||
if s.RcloneStorage != nil {
|
||||
if err := tx.Delete(s.RcloneStorage).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Delete the main storage
|
||||
|
||||
706
backend/internal/features/tests/mariadb_backup_restore_test.go
Normal file
706
backend/internal/features/tests/mariadb_backup_restore_test.go
Normal file
@@ -0,0 +1,706 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mariadbtypes "postgresus-backend/internal/features/databases/databases/mariadb"
|
||||
"postgresus-backend/internal/features/restores"
|
||||
restores_enums "postgresus-backend/internal/features/restores/enums"
|
||||
restores_models "postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
// SQL used to build a small, deterministic fixture in the source MariaDB
// database before a backup is taken. The restored copy is later compared
// against these rows by verifyMariadbDataIntegrity.
const dropMariadbTestTableQuery = `DROP TABLE IF EXISTS test_data`

const createMariadbTestTableQuery = `
CREATE TABLE test_data (
	id INT AUTO_INCREMENT PRIMARY KEY,
	name VARCHAR(255) NOT NULL,
	value INT NOT NULL,
	created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`

const insertMariadbTestDataQuery = `
INSERT INTO test_data (name, value) VALUES
	('test1', 100),
	('test2', 200),
	('test3', 300)`
|
||||
|
||||
// MariadbContainer describes a reachable MariaDB test instance together
// with an open sqlx connection to it.
type MariadbContainer struct {
	Host     string
	Port     int
	Username string
	Password string
	Database string
	Version  tools.MariadbVersion
	DB       *sqlx.DB
}
|
||||
|
||||
// MariadbTestDataItem mirrors one row of the test_data fixture table
// (see createMariadbTestTableQuery).
type MariadbTestDataItem struct {
	ID        int       `db:"id"`
	Name      string    `db:"name"`
	Value     int       `db:"value"`
	CreatedAt time.Time `db:"created_at"`
}
|
||||
|
||||
func Test_BackupAndRestoreMariadb_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMariadbBackupRestoreForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMariadbWithEncryption_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMariadbBackupRestoreWithEncryptionForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMariadb_WithReadOnlyUser_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MariadbVersion
|
||||
port string
|
||||
}{
|
||||
{"MariaDB 5.5", tools.MariadbVersion55, env.TestMariadb55Port},
|
||||
{"MariaDB 10.1", tools.MariadbVersion101, env.TestMariadb101Port},
|
||||
{"MariaDB 10.2", tools.MariadbVersion102, env.TestMariadb102Port},
|
||||
{"MariaDB 10.3", tools.MariadbVersion103, env.TestMariadb103Port},
|
||||
{"MariaDB 10.4", tools.MariadbVersion104, env.TestMariadb104Port},
|
||||
{"MariaDB 10.5", tools.MariadbVersion105, env.TestMariadb105Port},
|
||||
{"MariaDB 10.6", tools.MariadbVersion106, env.TestMariadb106Port},
|
||||
{"MariaDB 10.11", tools.MariadbVersion1011, env.TestMariadb1011Port},
|
||||
{"MariaDB 11.4", tools.MariadbVersion114, env.TestMariadb114Port},
|
||||
{"MariaDB 11.8", tools.MariadbVersion118, env.TestMariadb118Port},
|
||||
{"MariaDB 12.0", tools.MariadbVersion120, env.TestMariadb120Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMariadbBackupRestoreWithReadOnlyUserForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testMariadbBackupRestoreForVersion exercises the happy path for one
// MariaDB version: seed fixture data, register the database via the HTTP
// API, back it up to a test storage, restore the backup into a freshly
// created database on the same server, and verify the restored rows match
// the originals. All fixtures are cleaned up at the end.
func testMariadbBackupRestoreForVersion(
	t *testing.T,
	mariadbVersion tools.MariadbVersion,
	port string,
) {
	container, err := connectToMariadbContainer(mariadbVersion, port)
	if err != nil {
		// Container unreachable (e.g. port unset locally) — skip, not fail.
		t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
		return
	}
	defer func() {
		if container.DB != nil {
			container.DB.Close()
		}
	}()

	setupMariadbTestData(t, container.DB)

	router := createTestRouter()
	user := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace("MariaDB Test Workspace", user, router)

	storage := storages.CreateTestStorage(workspace.ID)

	database := createMariadbDatabaseViaAPI(
		t, router, "MariaDB Test Database", workspace.ID,
		container.Host, container.Port,
		container.Username, container.Password, container.Database,
		container.Version,
		user.Token,
	)

	// Unencrypted backup for the baseline round trip.
	enableBackupsViaAPI(
		t, router, database.ID, storage.ID,
		backups_config.BackupEncryptionNone, user.Token,
	)

	createBackupViaAPI(t, router, database.ID, user.Token)

	backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
	assert.Equal(t, backups.BackupStatusCompleted, backup.Status)

	// Restore into a brand-new database on the same server.
	newDBName := "restoreddb_mariadb"
	_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
	assert.NoError(t, err)

	_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
	assert.NoError(t, err)

	newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		container.Username, container.Password, container.Host, container.Port, newDBName)
	newDB, err := sqlx.Connect("mysql", newDSN)
	assert.NoError(t, err)
	defer newDB.Close()

	createMariadbRestoreViaAPI(
		t, router, backup.ID,
		container.Host, container.Port,
		container.Username, container.Password, newDBName,
		container.Version,
		user.Token,
	)

	restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
	assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)

	// The fixture table must exist in the restored database.
	var tableExists int
	err = newDB.Get(
		&tableExists,
		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
		newDBName,
	)
	assert.NoError(t, err)
	assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")

	verifyMariadbDataIntegrity(t, container.DB, newDB)

	// Cleanup: backup artifact on disk, database record, storage, workspace.
	err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
	if err != nil {
		t.Logf("Warning: Failed to delete backup file: %v", err)
	}

	test_utils.MakeDeleteRequest(
		t,
		router,
		"/api/v1/databases/"+database.ID.String(),
		"Bearer "+user.Token,
		http.StatusNoContent,
	)
	storages.RemoveTestStorage(storage.ID)
	workspaces_testing.RemoveTestWorkspace(workspace, router)
}
|
||||
|
||||
// testMariadbBackupRestoreWithEncryptionForVersion is the encrypted-backup
// variant of the round trip: identical flow to
// testMariadbBackupRestoreForVersion, but backups are created with
// BackupEncryptionEncrypted and the backup record is asserted to carry
// that encryption mode.
func testMariadbBackupRestoreWithEncryptionForVersion(
	t *testing.T,
	mariadbVersion tools.MariadbVersion,
	port string,
) {
	container, err := connectToMariadbContainer(mariadbVersion, port)
	if err != nil {
		t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
		return
	}
	defer func() {
		if container.DB != nil {
			container.DB.Close()
		}
	}()

	setupMariadbTestData(t, container.DB)

	router := createTestRouter()
	user := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace(
		"MariaDB Encrypted Test Workspace",
		user,
		router,
	)

	storage := storages.CreateTestStorage(workspace.ID)

	database := createMariadbDatabaseViaAPI(
		t, router, "MariaDB Encrypted Test Database", workspace.ID,
		container.Host, container.Port,
		container.Username, container.Password, container.Database,
		container.Version,
		user.Token,
	)

	enableBackupsViaAPI(
		t, router, database.ID, storage.ID,
		backups_config.BackupEncryptionEncrypted, user.Token,
	)

	createBackupViaAPI(t, router, database.ID, user.Token)

	backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
	assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
	// The stored backup must actually be marked as encrypted.
	assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)

	// Restore into a brand-new database on the same server.
	newDBName := "restoreddb_mariadb_encrypted"
	_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
	assert.NoError(t, err)

	_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
	assert.NoError(t, err)

	newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		container.Username, container.Password, container.Host, container.Port, newDBName)
	newDB, err := sqlx.Connect("mysql", newDSN)
	assert.NoError(t, err)
	defer newDB.Close()

	createMariadbRestoreViaAPI(
		t, router, backup.ID,
		container.Host, container.Port,
		container.Username, container.Password, newDBName,
		container.Version,
		user.Token,
	)

	restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
	assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)

	var tableExists int
	err = newDB.Get(
		&tableExists,
		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
		newDBName,
	)
	assert.NoError(t, err)
	assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")

	verifyMariadbDataIntegrity(t, container.DB, newDB)

	// Cleanup: backup artifact on disk, database record, storage, workspace.
	err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
	if err != nil {
		t.Logf("Warning: Failed to delete backup file: %v", err)
	}

	test_utils.MakeDeleteRequest(
		t,
		router,
		"/api/v1/databases/"+database.ID.String(),
		"Bearer "+user.Token,
		http.StatusNoContent,
	)
	storages.RemoveTestStorage(storage.ID)
	workspaces_testing.RemoveTestWorkspace(workspace, router)
}
|
||||
|
||||
// testMariadbBackupRestoreWithReadOnlyUserForVersion provisions a read-only
// user via the API, switches the registered database to those credentials,
// and then runs the standard backup/restore round trip. The backup is taken
// with the read-only credentials; the restore deliberately uses the admin
// credentials (container.Username), since a read-only user cannot write the
// restored data.
func testMariadbBackupRestoreWithReadOnlyUserForVersion(
	t *testing.T,
	mariadbVersion tools.MariadbVersion,
	port string,
) {
	container, err := connectToMariadbContainer(mariadbVersion, port)
	if err != nil {
		t.Skipf("Skipping MariaDB %s test: %v", mariadbVersion, err)
		return
	}
	defer func() {
		if container.DB != nil {
			container.DB.Close()
		}
	}()

	setupMariadbTestData(t, container.DB)

	router := createTestRouter()
	user := users_testing.CreateTestUser(users_enums.UserRoleMember)
	workspace := workspaces_testing.CreateTestWorkspace(
		"MariaDB ReadOnly Test Workspace",
		user,
		router,
	)

	storage := storages.CreateTestStorage(workspace.ID)

	database := createMariadbDatabaseViaAPI(
		t, router, "MariaDB ReadOnly Test Database", workspace.ID,
		container.Host, container.Port,
		container.Username, container.Password, container.Database,
		container.Version,
		user.Token,
	)

	// Provision a read-only user and re-point the database record at it.
	readOnlyUser := createMariadbReadOnlyUserViaAPI(t, router, database.ID, user.Token)
	assert.NotEmpty(t, readOnlyUser.Username)
	assert.NotEmpty(t, readOnlyUser.Password)

	updatedDatabase := updateMariadbDatabaseCredentialsViaAPI(
		t, router, database,
		readOnlyUser.Username, readOnlyUser.Password,
		user.Token,
	)

	enableBackupsViaAPI(
		t, router, updatedDatabase.ID, storage.ID,
		backups_config.BackupEncryptionNone, user.Token,
	)

	createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)

	backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
	assert.Equal(t, backups.BackupStatusCompleted, backup.Status)

	// Restore into a brand-new database on the same server.
	newDBName := "restoreddb_mariadb_readonly"
	_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
	assert.NoError(t, err)

	_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
	assert.NoError(t, err)

	newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
		container.Username, container.Password, container.Host, container.Port, newDBName)
	newDB, err := sqlx.Connect("mysql", newDSN)
	assert.NoError(t, err)
	defer newDB.Close()

	createMariadbRestoreViaAPI(
		t, router, backup.ID,
		container.Host, container.Port,
		container.Username, container.Password, newDBName,
		container.Version,
		user.Token,
	)

	restore := waitForMariadbRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
	assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)

	var tableExists int
	err = newDB.Get(
		&tableExists,
		"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
		newDBName,
	)
	assert.NoError(t, err)
	assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")

	verifyMariadbDataIntegrity(t, container.DB, newDB)

	// Cleanup: backup artifact on disk, database record, storage, workspace.
	err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
	if err != nil {
		t.Logf("Warning: Failed to delete backup file: %v", err)
	}

	test_utils.MakeDeleteRequest(
		t,
		router,
		"/api/v1/databases/"+updatedDatabase.ID.String(),
		"Bearer "+user.Token,
		http.StatusNoContent,
	)
	storages.RemoveTestStorage(storage.ID)
	workspaces_testing.RemoveTestWorkspace(workspace, router)
}
|
||||
|
||||
// createMariadbDatabaseViaAPI registers a MariaDB database through the HTTP
// API and returns the created record. Any non-201 response or unmarshal
// failure aborts the test.
func createMariadbDatabaseViaAPI(
	t *testing.T,
	router *gin.Engine,
	name string,
	workspaceID uuid.UUID,
	host string,
	port int,
	username string,
	password string,
	database string,
	version tools.MariadbVersion,
	token string,
) *databases.Database {
	request := databases.Database{
		Name:        name,
		WorkspaceID: &workspaceID,
		Type:        databases.DatabaseTypeMariadb,
		Mariadb: &mariadbtypes.MariadbDatabase{
			Host:     host,
			Port:     port,
			Username: username,
			Password: password,
			Database: &database,
			Version:  version,
		},
	}

	w := workspaces_testing.MakeAPIRequest(
		router,
		"POST",
		"/api/v1/databases/create",
		"Bearer "+token,
		request,
	)

	if w.Code != http.StatusCreated {
		t.Fatalf("Failed to create MariaDB database. Status: %d, Body: %s", w.Code, w.Body.String())
	}

	var createdDatabase databases.Database
	if err := json.Unmarshal(w.Body.Bytes(), &createdDatabase); err != nil {
		t.Fatalf("Failed to unmarshal database response: %v", err)
	}

	return &createdDatabase
}
|
||||
|
||||
// createMariadbRestoreViaAPI requests a restore of backupID into the given
// MariaDB target through the HTTP API and asserts a 200 response. The
// restore itself runs asynchronously; callers poll with
// waitForMariadbRestoreCompletion.
func createMariadbRestoreViaAPI(
	t *testing.T,
	router *gin.Engine,
	backupID uuid.UUID,
	host string,
	port int,
	username string,
	password string,
	database string,
	version tools.MariadbVersion,
	token string,
) {
	request := restores.RestoreBackupRequest{
		MariadbDatabase: &mariadbtypes.MariadbDatabase{
			Host:     host,
			Port:     port,
			Username: username,
			Password: password,
			Database: &database,
			Version:  version,
		},
	}

	test_utils.MakePostRequest(
		t,
		router,
		fmt.Sprintf("/api/v1/restores/%s/restore", backupID.String()),
		"Bearer "+token,
		request,
		http.StatusOK,
	)
}
|
||||
|
||||
// waitForMariadbRestoreCompletion polls the restores endpoint for backupID
// every 500ms until some restore reports Completed (returned), one reports
// Failed (test aborted with the failure message), or the timeout elapses
// (test aborted).
func waitForMariadbRestoreCompletion(
	t *testing.T,
	router *gin.Engine,
	backupID uuid.UUID,
	token string,
	timeout time.Duration,
) *restores_models.Restore {
	startTime := time.Now()
	pollInterval := 500 * time.Millisecond

	for {
		if time.Since(startTime) > timeout {
			t.Fatalf("Timeout waiting for MariaDB restore completion after %v", timeout)
		}

		var restoresList []*restores_models.Restore
		test_utils.MakeGetRequestAndUnmarshal(
			t,
			router,
			fmt.Sprintf("/api/v1/restores/%s", backupID.String()),
			"Bearer "+token,
			http.StatusOK,
			&restoresList,
		)

		for _, restore := range restoresList {
			if restore.Status == restores_enums.RestoreStatusCompleted {
				return restore
			}
			if restore.Status == restores_enums.RestoreStatusFailed {
				failMsg := "unknown error"
				// FailMessage is a nullable pointer; guard before deref.
				if restore.FailMessage != nil {
					failMsg = *restore.FailMessage
				}
				t.Fatalf("MariaDB restore failed: %s", failMsg)
			}
		}

		time.Sleep(pollInterval)
	}
}
|
||||
|
||||
func verifyMariadbDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
|
||||
var originalData []MariadbTestDataItem
|
||||
var restoredData []MariadbTestDataItem
|
||||
|
||||
err := originalDB.Select(
|
||||
&originalData,
|
||||
"SELECT id, name, value, created_at FROM test_data ORDER BY id",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = restoredDB.Select(
|
||||
&restoredData,
|
||||
"SELECT id, name, value, created_at FROM test_data ORDER BY id",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")
|
||||
|
||||
if len(originalData) > 0 && len(restoredData) > 0 {
|
||||
for i := range originalData {
|
||||
assert.Equal(t, originalData[i].ID, restoredData[i].ID, "ID should match")
|
||||
assert.Equal(t, originalData[i].Name, restoredData[i].Name, "Name should match")
|
||||
assert.Equal(t, originalData[i].Value, restoredData[i].Value, "Value should match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectToMariadbContainer(
|
||||
version tools.MariadbVersion,
|
||||
port string,
|
||||
) (*MariadbContainer, error) {
|
||||
if port == "" {
|
||||
return nil, fmt.Errorf("MariaDB %s port not configured", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse port: %w", err)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to MariaDB database: %w", err)
|
||||
}
|
||||
|
||||
return &MariadbContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupMariadbTestData(t *testing.T, db *sqlx.DB) {
|
||||
_, err := db.Exec(dropMariadbTestTableQuery)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(createMariadbTestTableQuery)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(insertMariadbTestDataQuery)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func createMariadbReadOnlyUserViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
) *databases.CreateReadOnlyUserResponse {
|
||||
var database databases.Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
&database,
|
||||
)
|
||||
|
||||
var response databases.CreateReadOnlyUserResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create-readonly-user",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
return &response
|
||||
}
|
||||
|
||||
func updateMariadbDatabaseCredentialsViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
database *databases.Database,
|
||||
username string,
|
||||
password string,
|
||||
token string,
|
||||
) *databases.Database {
|
||||
database.Mariadb.Username = username
|
||||
database.Mariadb.Password = password
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Failed to update MariaDB database. Status: %d, Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var updatedDatabase databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &updatedDatabase); err != nil {
|
||||
t.Fatalf("Failed to unmarshal database response: %v", err)
|
||||
}
|
||||
|
||||
return &updatedDatabase
|
||||
}
|
||||
675
backend/internal/features/tests/mysql_backup_restore_test.go
Normal file
675
backend/internal/features/tests/mysql_backup_restore_test.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"postgresus-backend/internal/config"
|
||||
"postgresus-backend/internal/features/backups/backups"
|
||||
backups_config "postgresus-backend/internal/features/backups/config"
|
||||
"postgresus-backend/internal/features/databases"
|
||||
mysqltypes "postgresus-backend/internal/features/databases/databases/mysql"
|
||||
"postgresus-backend/internal/features/restores"
|
||||
restores_enums "postgresus-backend/internal/features/restores/enums"
|
||||
restores_models "postgresus-backend/internal/features/restores/models"
|
||||
"postgresus-backend/internal/features/storages"
|
||||
users_enums "postgresus-backend/internal/features/users/enums"
|
||||
users_testing "postgresus-backend/internal/features/users/testing"
|
||||
workspaces_testing "postgresus-backend/internal/features/workspaces/testing"
|
||||
test_utils "postgresus-backend/internal/util/testing"
|
||||
"postgresus-backend/internal/util/tools"
|
||||
)
|
||||
|
||||
const dropMysqlTestTableQuery = `DROP TABLE IF EXISTS test_data`
|
||||
|
||||
const createMysqlTestTableQuery = `
|
||||
CREATE TABLE test_data (
|
||||
id INT AUTO_INCREMENT PRIMARY KEY,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
value INT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`
|
||||
|
||||
const insertMysqlTestDataQuery = `
|
||||
INSERT INTO test_data (name, value) VALUES
|
||||
('test1', 100),
|
||||
('test2', 200),
|
||||
('test3', 300)`
|
||||
|
||||
type MysqlContainer struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
Database string
|
||||
Version tools.MysqlVersion
|
||||
DB *sqlx.DB
|
||||
}
|
||||
|
||||
type MysqlTestDataItem struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
Value int `db:"value"`
|
||||
CreatedAt time.Time `db:"created_at"`
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMysql_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMysqlBackupRestoreForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMysqlWithEncryption_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMysqlBackupRestoreWithEncryptionForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_BackupAndRestoreMysql_WithReadOnlyUser_RestoreIsSuccessful(t *testing.T) {
|
||||
env := config.GetEnv()
|
||||
cases := []struct {
|
||||
name string
|
||||
version tools.MysqlVersion
|
||||
port string
|
||||
}{
|
||||
{"MySQL 5.7", tools.MysqlVersion57, env.TestMysql57Port},
|
||||
{"MySQL 8.0", tools.MysqlVersion80, env.TestMysql80Port},
|
||||
{"MySQL 8.4", tools.MysqlVersion84, env.TestMysql84Port},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
testMysqlBackupRestoreWithReadOnlyUserForVersion(t, tc.version, tc.port)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testMysqlBackupRestoreForVersion(t *testing.T, mysqlVersion tools.MysqlVersion, port string) {
|
||||
container, err := connectToMysqlContainer(mysqlVersion, port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MySQL %s test: %v", mysqlVersion, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMysqlTestData(t, container.DB)
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace("MySQL Test Workspace", user, router)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMysqlDatabaseViaAPI(
|
||||
t, router, "MySQL Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, database.ID, storage.ID,
|
||||
backups_config.BackupEncryptionNone, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMysqlRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMysqlDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+database.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func testMysqlBackupRestoreWithEncryptionForVersion(
|
||||
t *testing.T,
|
||||
mysqlVersion tools.MysqlVersion,
|
||||
port string,
|
||||
) {
|
||||
container, err := connectToMysqlContainer(mysqlVersion, port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MySQL %s test: %v", mysqlVersion, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMysqlTestData(t, container.DB)
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"MySQL Encrypted Test Workspace",
|
||||
user,
|
||||
router,
|
||||
)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMysqlDatabaseViaAPI(
|
||||
t, router, "MySQL Encrypted Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, database.ID, storage.ID,
|
||||
backups_config.BackupEncryptionEncrypted, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, database.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, database.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
assert.Equal(t, backups_config.BackupEncryptionEncrypted, backup.Encryption)
|
||||
|
||||
newDBName := "restoreddb_mysql_encrypted"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMysqlRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMysqlDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+database.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func testMysqlBackupRestoreWithReadOnlyUserForVersion(
|
||||
t *testing.T,
|
||||
mysqlVersion tools.MysqlVersion,
|
||||
port string,
|
||||
) {
|
||||
container, err := connectToMysqlContainer(mysqlVersion, port)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping MySQL %s test: %v", mysqlVersion, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if container.DB != nil {
|
||||
container.DB.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
setupMysqlTestData(t, container.DB)
|
||||
|
||||
router := createTestRouter()
|
||||
user := users_testing.CreateTestUser(users_enums.UserRoleMember)
|
||||
workspace := workspaces_testing.CreateTestWorkspace(
|
||||
"MySQL ReadOnly Test Workspace",
|
||||
user,
|
||||
router,
|
||||
)
|
||||
|
||||
storage := storages.CreateTestStorage(workspace.ID)
|
||||
|
||||
database := createMysqlDatabaseViaAPI(
|
||||
t, router, "MySQL ReadOnly Test Database", workspace.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, container.Database,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
readOnlyUser := createMysqlReadOnlyUserViaAPI(t, router, database.ID, user.Token)
|
||||
assert.NotEmpty(t, readOnlyUser.Username)
|
||||
assert.NotEmpty(t, readOnlyUser.Password)
|
||||
|
||||
updatedDatabase := updateMysqlDatabaseCredentialsViaAPI(
|
||||
t, router, database,
|
||||
readOnlyUser.Username, readOnlyUser.Password,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
enableBackupsViaAPI(
|
||||
t, router, updatedDatabase.ID, storage.ID,
|
||||
backups_config.BackupEncryptionNone, user.Token,
|
||||
)
|
||||
|
||||
createBackupViaAPI(t, router, updatedDatabase.ID, user.Token)
|
||||
|
||||
backup := waitForBackupCompletion(t, router, updatedDatabase.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, backups.BackupStatusCompleted, backup.Status)
|
||||
|
||||
newDBName := "restoreddb_mysql_readonly"
|
||||
_, err = container.DB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = container.DB.Exec(fmt.Sprintf("CREATE DATABASE %s;", newDBName))
|
||||
assert.NoError(t, err)
|
||||
|
||||
newDSN := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
container.Username, container.Password, container.Host, container.Port, newDBName)
|
||||
newDB, err := sqlx.Connect("mysql", newDSN)
|
||||
assert.NoError(t, err)
|
||||
defer newDB.Close()
|
||||
|
||||
createMysqlRestoreViaAPI(
|
||||
t, router, backup.ID,
|
||||
container.Host, container.Port,
|
||||
container.Username, container.Password, newDBName,
|
||||
container.Version,
|
||||
user.Token,
|
||||
)
|
||||
|
||||
restore := waitForMysqlRestoreCompletion(t, router, backup.ID, user.Token, 5*time.Minute)
|
||||
assert.Equal(t, restores_enums.RestoreStatusCompleted, restore.Status)
|
||||
|
||||
var tableExists int
|
||||
err = newDB.Get(
|
||||
&tableExists,
|
||||
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = 'test_data'",
|
||||
newDBName,
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, tableExists, "Table 'test_data' should exist in restored database")
|
||||
|
||||
verifyMysqlDataIntegrity(t, container.DB, newDB)
|
||||
|
||||
err = os.Remove(filepath.Join(config.GetEnv().DataFolder, backup.ID.String()))
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to delete backup file: %v", err)
|
||||
}
|
||||
|
||||
test_utils.MakeDeleteRequest(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/"+updatedDatabase.ID.String(),
|
||||
"Bearer "+user.Token,
|
||||
http.StatusNoContent,
|
||||
)
|
||||
storages.RemoveTestStorage(storage.ID)
|
||||
workspaces_testing.RemoveTestWorkspace(workspace, router)
|
||||
}
|
||||
|
||||
func createMysqlDatabaseViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
name string,
|
||||
workspaceID uuid.UUID,
|
||||
host string,
|
||||
port int,
|
||||
username string,
|
||||
password string,
|
||||
database string,
|
||||
version tools.MysqlVersion,
|
||||
token string,
|
||||
) *databases.Database {
|
||||
request := databases.Database{
|
||||
Name: name,
|
||||
WorkspaceID: &workspaceID,
|
||||
Type: databases.DatabaseTypeMysql,
|
||||
Mysql: &mysqltypes.MysqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: &database,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/create",
|
||||
"Bearer "+token,
|
||||
request,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("Failed to create MySQL database. Status: %d, Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var createdDatabase databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &createdDatabase); err != nil {
|
||||
t.Fatalf("Failed to unmarshal database response: %v", err)
|
||||
}
|
||||
|
||||
return &createdDatabase
|
||||
}
|
||||
|
||||
func createMysqlRestoreViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
backupID uuid.UUID,
|
||||
host string,
|
||||
port int,
|
||||
username string,
|
||||
password string,
|
||||
database string,
|
||||
version tools.MysqlVersion,
|
||||
token string,
|
||||
) {
|
||||
request := restores.RestoreBackupRequest{
|
||||
MysqlDatabase: &mysqltypes.MysqlDatabase{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: &database,
|
||||
Version: version,
|
||||
},
|
||||
}
|
||||
|
||||
test_utils.MakePostRequest(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s/restore", backupID.String()),
|
||||
"Bearer "+token,
|
||||
request,
|
||||
http.StatusOK,
|
||||
)
|
||||
}
|
||||
|
||||
func waitForMysqlRestoreCompletion(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
backupID uuid.UUID,
|
||||
token string,
|
||||
timeout time.Duration,
|
||||
) *restores_models.Restore {
|
||||
startTime := time.Now()
|
||||
pollInterval := 500 * time.Millisecond
|
||||
|
||||
for {
|
||||
if time.Since(startTime) > timeout {
|
||||
t.Fatalf("Timeout waiting for MySQL restore completion after %v", timeout)
|
||||
}
|
||||
|
||||
var restoresList []*restores_models.Restore
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/restores/%s", backupID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
&restoresList,
|
||||
)
|
||||
|
||||
for _, restore := range restoresList {
|
||||
if restore.Status == restores_enums.RestoreStatusCompleted {
|
||||
return restore
|
||||
}
|
||||
if restore.Status == restores_enums.RestoreStatusFailed {
|
||||
failMsg := "unknown error"
|
||||
if restore.FailMessage != nil {
|
||||
failMsg = *restore.FailMessage
|
||||
}
|
||||
t.Fatalf("MySQL restore failed: %s", failMsg)
|
||||
}
|
||||
}
|
||||
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyMysqlDataIntegrity(t *testing.T, originalDB *sqlx.DB, restoredDB *sqlx.DB) {
|
||||
var originalData []MysqlTestDataItem
|
||||
var restoredData []MysqlTestDataItem
|
||||
|
||||
err := originalDB.Select(
|
||||
&originalData,
|
||||
"SELECT id, name, value, created_at FROM test_data ORDER BY id",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = restoredDB.Select(
|
||||
&restoredData,
|
||||
"SELECT id, name, value, created_at FROM test_data ORDER BY id",
|
||||
)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, len(originalData), len(restoredData), "Should have same number of rows")
|
||||
|
||||
if len(originalData) > 0 && len(restoredData) > 0 {
|
||||
for i := range originalData {
|
||||
assert.Equal(t, originalData[i].ID, restoredData[i].ID, "ID should match")
|
||||
assert.Equal(t, originalData[i].Name, restoredData[i].Name, "Name should match")
|
||||
assert.Equal(t, originalData[i].Value, restoredData[i].Value, "Value should match")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func connectToMysqlContainer(version tools.MysqlVersion, port string) (*MysqlContainer, error) {
|
||||
if port == "" {
|
||||
return nil, fmt.Errorf("MySQL %s port not configured", version)
|
||||
}
|
||||
|
||||
dbName := "testdb"
|
||||
password := "rootpassword"
|
||||
username := "root"
|
||||
host := "127.0.0.1"
|
||||
|
||||
portInt, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse port: %w", err)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?parseTime=true",
|
||||
username, password, host, portInt, dbName)
|
||||
|
||||
db, err := sqlx.Connect("mysql", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to MySQL database: %w", err)
|
||||
}
|
||||
|
||||
return &MysqlContainer{
|
||||
Host: host,
|
||||
Port: portInt,
|
||||
Username: username,
|
||||
Password: password,
|
||||
Database: dbName,
|
||||
Version: version,
|
||||
DB: db,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func setupMysqlTestData(t *testing.T, db *sqlx.DB) {
|
||||
_, err := db.Exec(dropMysqlTestTableQuery)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(createMysqlTestTableQuery)
|
||||
assert.NoError(t, err)
|
||||
|
||||
_, err = db.Exec(insertMysqlTestDataQuery)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func createMysqlReadOnlyUserViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
databaseID uuid.UUID,
|
||||
token string,
|
||||
) *databases.CreateReadOnlyUserResponse {
|
||||
var database databases.Database
|
||||
test_utils.MakeGetRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
fmt.Sprintf("/api/v1/databases/%s", databaseID.String()),
|
||||
"Bearer "+token,
|
||||
http.StatusOK,
|
||||
&database,
|
||||
)
|
||||
|
||||
var response databases.CreateReadOnlyUserResponse
|
||||
test_utils.MakePostRequestAndUnmarshal(
|
||||
t,
|
||||
router,
|
||||
"/api/v1/databases/create-readonly-user",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
http.StatusOK,
|
||||
&response,
|
||||
)
|
||||
|
||||
return &response
|
||||
}
|
||||
|
||||
func updateMysqlDatabaseCredentialsViaAPI(
|
||||
t *testing.T,
|
||||
router *gin.Engine,
|
||||
database *databases.Database,
|
||||
username string,
|
||||
password string,
|
||||
token string,
|
||||
) *databases.Database {
|
||||
database.Mysql.Username = username
|
||||
database.Mysql.Password = password
|
||||
|
||||
w := workspaces_testing.MakeAPIRequest(
|
||||
router,
|
||||
"POST",
|
||||
"/api/v1/databases/update",
|
||||
"Bearer "+token,
|
||||
database,
|
||||
)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Failed to update MySQL database. Status: %d, Body: %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var updatedDatabase databases.Database
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &updatedDatabase); err != nil {
|
||||
t.Fatalf("Failed to unmarshal database response: %v", err)
|
||||
}
|
||||
|
||||
return &updatedDatabase
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
241
backend/internal/util/tools/mariadb.go
Normal file
241
backend/internal/util/tools/mariadb.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
env_utils "postgresus-backend/internal/util/env"
|
||||
)
|
||||
|
||||
type MariadbVersion string
|
||||
|
||||
const (
|
||||
MariadbVersion55 MariadbVersion = "5.5"
|
||||
MariadbVersion101 MariadbVersion = "10.1"
|
||||
MariadbVersion102 MariadbVersion = "10.2"
|
||||
MariadbVersion103 MariadbVersion = "10.3"
|
||||
MariadbVersion104 MariadbVersion = "10.4"
|
||||
MariadbVersion105 MariadbVersion = "10.5"
|
||||
MariadbVersion106 MariadbVersion = "10.6"
|
||||
MariadbVersion1011 MariadbVersion = "10.11"
|
||||
MariadbVersion114 MariadbVersion = "11.4"
|
||||
MariadbVersion118 MariadbVersion = "11.8"
|
||||
MariadbVersion120 MariadbVersion = "12.0"
|
||||
)
|
||||
|
||||
// MariadbClientVersion represents the client tool version to use
|
||||
type MariadbClientVersion string
|
||||
|
||||
const (
|
||||
// MariadbClientLegacy is used for older MariaDB servers (5.5, 10.1) that don't support
|
||||
// the generation_expression column in information_schema.columns
|
||||
MariadbClientLegacy MariadbClientVersion = "10.6"
|
||||
// MariadbClientModern is used for newer MariaDB servers (10.2+)
|
||||
MariadbClientModern MariadbClientVersion = "12.1"
|
||||
)
|
||||
|
||||
type MariadbExecutable string
|
||||
|
||||
const (
|
||||
MariadbExecutableMariadbDump MariadbExecutable = "mariadb-dump"
|
||||
MariadbExecutableMariadb MariadbExecutable = "mariadb"
|
||||
)
|
||||
|
||||
// GetMariadbClientVersionForServer returns the appropriate client version to use
|
||||
// for a given server version. MariaDB 12.1 client uses SQL queries that reference
|
||||
// the generation_expression column which was added in MariaDB 10.2, so older
|
||||
// servers (5.5, 10.1) need the legacy 10.6 client.
|
||||
func GetMariadbClientVersionForServer(serverVersion MariadbVersion) MariadbClientVersion {
|
||||
switch serverVersion {
|
||||
case MariadbVersion55, MariadbVersion101:
|
||||
return MariadbClientLegacy
|
||||
default:
|
||||
return MariadbClientModern
|
||||
}
|
||||
}
|
||||
|
||||
// GetMariadbExecutable returns the full path to a MariaDB executable.
|
||||
// The serverVersion parameter determines which client tools to use:
|
||||
// - For MariaDB 5.5 and 10.1: uses legacy 10.6 client (compatible with older servers)
|
||||
// - For MariaDB 10.2+: uses modern 12.1 client
|
||||
func GetMariadbExecutable(
|
||||
executable MariadbExecutable,
|
||||
serverVersion MariadbVersion,
|
||||
envMode env_utils.EnvMode,
|
||||
mariadbInstallDir string,
|
||||
) string {
|
||||
clientVersion := GetMariadbClientVersionForServer(serverVersion)
|
||||
basePath := getMariadbBasePath(clientVersion, envMode, mariadbInstallDir)
|
||||
executableName := string(executable)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
executableName += ".exe"
|
||||
}
|
||||
|
||||
return filepath.Join(basePath, executableName)
|
||||
}
|
||||
|
||||
// VerifyMariadbInstallation verifies that MariaDB client tools are installed.
|
||||
// MariaDB uses two client versions:
|
||||
// - Legacy (10.6) for older servers (5.5, 10.1)
|
||||
// - Modern (12.1) for newer servers (10.2+)
|
||||
func VerifyMariadbInstallation(
|
||||
logger *slog.Logger,
|
||||
envMode env_utils.EnvMode,
|
||||
mariadbInstallDir string,
|
||||
) {
|
||||
clientVersions := []MariadbClientVersion{MariadbClientLegacy, MariadbClientModern}
|
||||
|
||||
for _, clientVersion := range clientVersions {
|
||||
binDir := getMariadbBasePath(clientVersion, envMode, mariadbInstallDir)
|
||||
|
||||
logger.Info(
|
||||
"Verifying MariaDB installation",
|
||||
"clientVersion", clientVersion,
|
||||
"path", binDir,
|
||||
)
|
||||
|
||||
if _, err := os.Stat(binDir); os.IsNotExist(err) {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
logger.Warn(
|
||||
"MariaDB bin directory not found. Some MariaDB versions may not be supported. Read ./tools/readme.md for details",
|
||||
"clientVersion",
|
||||
clientVersion,
|
||||
"path",
|
||||
binDir,
|
||||
)
|
||||
} else {
|
||||
logger.Warn(
|
||||
"MariaDB bin directory not found. Some MariaDB versions may not be supported.",
|
||||
"clientVersion", clientVersion,
|
||||
"path", binDir,
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
requiredCommands := []MariadbExecutable{
|
||||
MariadbExecutableMariadbDump,
|
||||
MariadbExecutableMariadb,
|
||||
}
|
||||
|
||||
for _, cmd := range requiredCommands {
|
||||
// Use a dummy server version that maps to this client version
|
||||
var dummyServerVersion MariadbVersion
|
||||
if clientVersion == MariadbClientLegacy {
|
||||
dummyServerVersion = MariadbVersion55
|
||||
} else {
|
||||
dummyServerVersion = MariadbVersion102
|
||||
}
|
||||
cmdPath := GetMariadbExecutable(cmd, dummyServerVersion, envMode, mariadbInstallDir)
|
||||
|
||||
logger.Info(
|
||||
"Checking for MariaDB command",
|
||||
"clientVersion", clientVersion,
|
||||
"command", cmd,
|
||||
"path", cmdPath,
|
||||
)
|
||||
|
||||
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
logger.Warn(
|
||||
"MariaDB command not found. Some MariaDB versions may not be supported. Read ./tools/readme.md for details",
|
||||
"clientVersion",
|
||||
clientVersion,
|
||||
"command",
|
||||
cmd,
|
||||
"path",
|
||||
cmdPath,
|
||||
)
|
||||
} else {
|
||||
logger.Warn(
|
||||
"MariaDB command not found. Some MariaDB versions may not be supported.",
|
||||
"clientVersion", clientVersion,
|
||||
"command", cmd,
|
||||
"path", cmdPath,
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info("MariaDB command found", "clientVersion", clientVersion, "command", cmd)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("MariaDB client tools verification completed!")
|
||||
}
|
||||
|
||||
// IsMariadbBackupVersionHigherThanRestoreVersion checks if backup was made with
|
||||
// a newer MariaDB version than the restore target
|
||||
func IsMariadbBackupVersionHigherThanRestoreVersion(
|
||||
backupVersion, restoreVersion MariadbVersion,
|
||||
) bool {
|
||||
versionOrder := map[MariadbVersion]int{
|
||||
MariadbVersion55: 1,
|
||||
MariadbVersion101: 2,
|
||||
MariadbVersion102: 3,
|
||||
MariadbVersion103: 4,
|
||||
MariadbVersion104: 5,
|
||||
MariadbVersion105: 6,
|
||||
MariadbVersion106: 7,
|
||||
MariadbVersion1011: 8,
|
||||
MariadbVersion114: 9,
|
||||
MariadbVersion118: 10,
|
||||
MariadbVersion120: 11,
|
||||
}
|
||||
return versionOrder[backupVersion] > versionOrder[restoreVersion]
|
||||
}
|
||||
|
||||
// GetMariadbVersionEnum converts a version string to MariadbVersion enum
|
||||
func GetMariadbVersionEnum(version string) MariadbVersion {
|
||||
switch version {
|
||||
case "5.5":
|
||||
return MariadbVersion55
|
||||
case "10.1":
|
||||
return MariadbVersion101
|
||||
case "10.2":
|
||||
return MariadbVersion102
|
||||
case "10.3":
|
||||
return MariadbVersion103
|
||||
case "10.4":
|
||||
return MariadbVersion104
|
||||
case "10.5":
|
||||
return MariadbVersion105
|
||||
case "10.6":
|
||||
return MariadbVersion106
|
||||
case "10.11":
|
||||
return MariadbVersion1011
|
||||
case "11.4":
|
||||
return MariadbVersion114
|
||||
case "11.8":
|
||||
return MariadbVersion118
|
||||
case "12.0":
|
||||
return MariadbVersion120
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid mariadb version: %s", version))
|
||||
}
|
||||
}
|
||||
|
||||
// EscapeMariadbPassword escapes special characters for MariaDB .my.cnf file format.
|
||||
func EscapeMariadbPassword(password string) string {
|
||||
password = strings.ReplaceAll(password, "\\", "\\\\")
|
||||
password = strings.ReplaceAll(password, "\"", "\\\"")
|
||||
return password
|
||||
}
|
||||
|
||||
func getMariadbBasePath(
|
||||
clientVersion MariadbClientVersion,
|
||||
envMode env_utils.EnvMode,
|
||||
mariadbInstallDir string,
|
||||
) string {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
// Development: tools/mariadb/mariadb-{version}/bin
|
||||
return filepath.Join(mariadbInstallDir, fmt.Sprintf("mariadb-%s", clientVersion), "bin")
|
||||
}
|
||||
// Production: /usr/local/mariadb-{version}/bin
|
||||
return fmt.Sprintf("/usr/local/mariadb-%s/bin", clientVersion)
|
||||
}
|
||||
214
backend/internal/util/tools/mysql.go
Normal file
214
backend/internal/util/tools/mysql.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package tools
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
env_utils "postgresus-backend/internal/util/env"
|
||||
)
|
||||
|
||||
type MysqlVersion string
|
||||
|
||||
const (
|
||||
MysqlVersion57 MysqlVersion = "5.7"
|
||||
MysqlVersion80 MysqlVersion = "8.0"
|
||||
MysqlVersion84 MysqlVersion = "8.4"
|
||||
)
|
||||
|
||||
type MysqlExecutable string
|
||||
|
||||
const (
|
||||
MysqlExecutableMysqldump MysqlExecutable = "mysqldump"
|
||||
MysqlExecutableMysql MysqlExecutable = "mysql"
|
||||
)
|
||||
|
||||
// GetMysqlExecutable returns the full path to a specific MySQL executable
|
||||
// for the given version. Common executables include: mysqldump, mysql.
|
||||
// On Windows, automatically appends .exe extension.
|
||||
func GetMysqlExecutable(
|
||||
version MysqlVersion,
|
||||
executable MysqlExecutable,
|
||||
envMode env_utils.EnvMode,
|
||||
mysqlInstallDir string,
|
||||
) string {
|
||||
basePath := getMysqlBasePath(version, envMode, mysqlInstallDir)
|
||||
executableName := string(executable)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
executableName += ".exe"
|
||||
}
|
||||
|
||||
return filepath.Join(basePath, executableName)
|
||||
}
|
||||
|
||||
// VerifyMysqlInstallation verifies that MySQL versions 5.7, 8.0, 8.4 are installed
|
||||
// in the current environment. Each version should be installed with the required
|
||||
// client tools (mysqldump, mysql) available.
|
||||
// In development: ./tools/mysql/mysql-{VERSION}/bin
|
||||
// In production: /usr/local/mysql-{VERSION}/bin
|
||||
func VerifyMysqlInstallation(
|
||||
logger *slog.Logger,
|
||||
envMode env_utils.EnvMode,
|
||||
mysqlInstallDir string,
|
||||
) {
|
||||
versions := []MysqlVersion{
|
||||
MysqlVersion57,
|
||||
MysqlVersion80,
|
||||
MysqlVersion84,
|
||||
}
|
||||
|
||||
requiredCommands := []MysqlExecutable{
|
||||
MysqlExecutableMysqldump,
|
||||
MysqlExecutableMysql,
|
||||
}
|
||||
|
||||
for _, version := range versions {
|
||||
binDir := getMysqlBasePath(version, envMode, mysqlInstallDir)
|
||||
|
||||
logger.Info(
|
||||
"Verifying MySQL installation",
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
binDir,
|
||||
)
|
||||
|
||||
if _, err := os.Stat(binDir); os.IsNotExist(err) {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
logger.Warn(
|
||||
"MySQL bin directory not found. MySQL support will be disabled. Read ./tools/readme.md for details",
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
binDir,
|
||||
)
|
||||
} else {
|
||||
logger.Warn(
|
||||
"MySQL bin directory not found. MySQL support will be disabled.",
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
binDir,
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for _, cmd := range requiredCommands {
|
||||
cmdPath := GetMysqlExecutable(
|
||||
version,
|
||||
cmd,
|
||||
envMode,
|
||||
mysqlInstallDir,
|
||||
)
|
||||
|
||||
logger.Info(
|
||||
"Checking for MySQL command",
|
||||
"command",
|
||||
cmd,
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
cmdPath,
|
||||
)
|
||||
|
||||
if _, err := os.Stat(cmdPath); os.IsNotExist(err) {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
logger.Warn(
|
||||
"MySQL command not found. MySQL support for this version will be disabled. Read ./tools/readme.md for details",
|
||||
"command",
|
||||
cmd,
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
cmdPath,
|
||||
)
|
||||
} else {
|
||||
logger.Warn(
|
||||
"MySQL command not found. MySQL support for this version will be disabled.",
|
||||
"command",
|
||||
cmd,
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
cmdPath,
|
||||
)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"MySQL command found",
|
||||
"command",
|
||||
cmd,
|
||||
"version",
|
||||
string(version),
|
||||
)
|
||||
}
|
||||
|
||||
logger.Info(
|
||||
"Installation of MySQL verified",
|
||||
"version",
|
||||
string(version),
|
||||
"path",
|
||||
binDir,
|
||||
)
|
||||
}
|
||||
|
||||
logger.Info("MySQL version-specific client tools verification completed!")
|
||||
}
|
||||
|
||||
// IsMysqlBackupVersionHigherThanRestoreVersion checks if backup was made with
|
||||
// a newer MySQL version than the restore target
|
||||
func IsMysqlBackupVersionHigherThanRestoreVersion(
|
||||
backupVersion, restoreVersion MysqlVersion,
|
||||
) bool {
|
||||
versionOrder := map[MysqlVersion]int{
|
||||
MysqlVersion57: 1,
|
||||
MysqlVersion80: 2,
|
||||
MysqlVersion84: 3,
|
||||
}
|
||||
return versionOrder[backupVersion] > versionOrder[restoreVersion]
|
||||
}
|
||||
|
||||
// EscapeMysqlPassword escapes special characters for MySQL .my.cnf file format.
|
||||
// In .my.cnf, passwords with special chars should be quoted.
|
||||
// Escape backslash and quote characters.
|
||||
func EscapeMysqlPassword(password string) string {
|
||||
password = strings.ReplaceAll(password, "\\", "\\\\")
|
||||
password = strings.ReplaceAll(password, "\"", "\\\"")
|
||||
return password
|
||||
}
|
||||
|
||||
// GetMysqlVersionEnum converts a version string to MysqlVersion enum
|
||||
func GetMysqlVersionEnum(version string) MysqlVersion {
|
||||
switch version {
|
||||
case "5.7":
|
||||
return MysqlVersion57
|
||||
case "8.0":
|
||||
return MysqlVersion80
|
||||
case "8.4":
|
||||
return MysqlVersion84
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid mysql version: %s", version))
|
||||
}
|
||||
}
|
||||
|
||||
func getMysqlBasePath(
|
||||
version MysqlVersion,
|
||||
envMode env_utils.EnvMode,
|
||||
mysqlInstallDir string,
|
||||
) string {
|
||||
if envMode == env_utils.EnvModeDevelopment {
|
||||
return filepath.Join(
|
||||
mysqlInstallDir,
|
||||
fmt.Sprintf("mysql-%s", string(version)),
|
||||
"bin",
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf("/usr/local/mysql-%s/bin", string(version))
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE postgresql_databases
|
||||
ADD COLUMN include_schemas TEXT NOT NULL DEFAULT '';
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE postgresql_databases
|
||||
DROP COLUMN include_schemas;
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,11 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE s3_storages
|
||||
ADD COLUMN skip_tls_verify BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
ALTER TABLE s3_storages
|
||||
DROP COLUMN skip_tls_verify;
|
||||
-- +goose StatementEnd
|
||||
29
backend/migrations/20251213180403_add_ftp_storages.sql
Normal file
29
backend/migrations/20251213180403_add_ftp_storages.sql
Normal file
@@ -0,0 +1,29 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
CREATE TABLE ftp_storages (
|
||||
storage_id UUID PRIMARY KEY,
|
||||
host TEXT NOT NULL,
|
||||
port INTEGER NOT NULL DEFAULT 21,
|
||||
username TEXT NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
path TEXT,
|
||||
use_ssl BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
skip_tls_verify BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
passive_mode BOOLEAN NOT NULL DEFAULT TRUE
|
||||
);
|
||||
|
||||
ALTER TABLE ftp_storages
|
||||
ADD CONSTRAINT fk_ftp_storages_storage
|
||||
FOREIGN KEY (storage_id)
|
||||
REFERENCES storages (id)
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
DROP TABLE IF EXISTS ftp_storages;
|
||||
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,15 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE ftp_storages
|
||||
DROP COLUMN passive_mode;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
ALTER TABLE ftp_storages
|
||||
ADD COLUMN passive_mode BOOLEAN NOT NULL DEFAULT TRUE;
|
||||
|
||||
-- +goose StatementEnd
|
||||
23
backend/migrations/20251218123447_add_rclone_storages.sql
Normal file
23
backend/migrations/20251218123447_add_rclone_storages.sql
Normal file
@@ -0,0 +1,23 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
CREATE TABLE rclone_storages (
|
||||
storage_id UUID PRIMARY KEY,
|
||||
config_content TEXT NOT NULL,
|
||||
remote_path TEXT
|
||||
);
|
||||
|
||||
ALTER TABLE rclone_storages
|
||||
ADD CONSTRAINT fk_rclone_storages_storage
|
||||
FOREIGN KEY (storage_id)
|
||||
REFERENCES storages (id)
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
DROP TABLE IF EXISTS rclone_storages;
|
||||
|
||||
-- +goose StatementEnd
|
||||
28
backend/migrations/20251219220027_add_sftp_storages.sql
Normal file
28
backend/migrations/20251219220027_add_sftp_storages.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
|
||||
CREATE TABLE sftp_storages (
|
||||
storage_id UUID PRIMARY KEY,
|
||||
host TEXT NOT NULL,
|
||||
port INTEGER NOT NULL DEFAULT 22,
|
||||
username TEXT NOT NULL,
|
||||
password TEXT,
|
||||
private_key TEXT,
|
||||
path TEXT,
|
||||
skip_host_key_verify BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
ALTER TABLE sftp_storages
|
||||
ADD CONSTRAINT fk_sftp_storages_storage
|
||||
FOREIGN KEY (storage_id)
|
||||
REFERENCES storages (id)
|
||||
ON DELETE CASCADE DEFERRABLE INITIALLY DEFERRED;
|
||||
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
|
||||
DROP TABLE IF EXISTS sftp_storages;
|
||||
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,5 @@
|
||||
-- +goose Up
|
||||
ALTER TABLE intervals ADD COLUMN cron_expression TEXT;
|
||||
|
||||
-- +goose Down
|
||||
ALTER TABLE intervals DROP COLUMN cron_expression;
|
||||
@@ -0,0 +1,27 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
CREATE TABLE mysql_databases (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
database_id UUID REFERENCES databases(id) ON DELETE CASCADE,
|
||||
version TEXT NOT NULL,
|
||||
host TEXT NOT NULL,
|
||||
port INT NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
database TEXT,
|
||||
is_https BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
CREATE INDEX idx_mysql_databases_database_id ON mysql_databases(database_id);
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP INDEX IF EXISTS idx_mysql_databases_database_id;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
DROP TABLE IF EXISTS mysql_databases;
|
||||
-- +goose StatementEnd
|
||||
@@ -0,0 +1,28 @@
|
||||
-- +goose Up
|
||||
-- +goose StatementBegin
|
||||
CREATE TABLE mariadb_databases (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
database_id UUID REFERENCES databases(id) ON DELETE CASCADE,
|
||||
version TEXT NOT NULL,
|
||||
host TEXT NOT NULL,
|
||||
port INT NOT NULL,
|
||||
username TEXT NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
database TEXT,
|
||||
is_https BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
CREATE INDEX idx_mariadb_databases_database_id ON mariadb_databases(database_id);
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose Down
|
||||
-- +goose StatementBegin
|
||||
DROP INDEX IF EXISTS idx_mariadb_databases_database_id;
|
||||
-- +goose StatementEnd
|
||||
|
||||
-- +goose StatementBegin
|
||||
DROP TABLE IF EXISTS mariadb_databases;
|
||||
-- +goose StatementEnd
|
||||
|
||||
4
backend/tools/.gitignore
vendored
4
backend/tools/.gitignore
vendored
@@ -1,2 +1,4 @@
|
||||
postgresql
|
||||
downloads
|
||||
mysql
|
||||
downloads
|
||||
mariadb
|
||||
@@ -5,7 +5,7 @@ set -e # Exit on any error
|
||||
# Ensure non-interactive mode for apt
|
||||
export DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
echo "Installing PostgreSQL client tools versions 12-18 for Linux (Debian/Ubuntu)..."
|
||||
echo "Installing PostgreSQL and MySQL client tools for Linux (Debian/Ubuntu)..."
|
||||
echo
|
||||
|
||||
# Check if running on supported system
|
||||
@@ -22,19 +22,27 @@ else
|
||||
echo "This script requires sudo privileges to install packages."
|
||||
fi
|
||||
|
||||
# Create postgresql directory
|
||||
# Create directories
|
||||
mkdir -p postgresql
|
||||
mkdir -p mysql
|
||||
|
||||
# Get absolute path
|
||||
# Get absolute paths
|
||||
POSTGRES_DIR="$(pwd)/postgresql"
|
||||
MYSQL_DIR="$(pwd)/mysql"
|
||||
|
||||
echo "Installing PostgreSQL client tools to: $POSTGRES_DIR"
|
||||
echo "Installing MySQL client tools to: $MYSQL_DIR"
|
||||
echo
|
||||
|
||||
# ========== PostgreSQL Installation ==========
|
||||
echo "========================================"
|
||||
echo "Installing PostgreSQL client tools (versions 12-18)..."
|
||||
echo "========================================"
|
||||
|
||||
# Add PostgreSQL official APT repository
|
||||
echo "Adding PostgreSQL official APT repository..."
|
||||
$SUDO apt-get update -qq -y
|
||||
$SUDO apt-get install -y -qq wget ca-certificates
|
||||
$SUDO apt-get install -y -qq wget ca-certificates gnupg lsb-release
|
||||
|
||||
# Add GPG key
|
||||
wget --quiet -O - https://www.postgresql.org/media/keys/ACCC4CF8.asc | $SUDO apt-key add - 2>/dev/null
|
||||
@@ -46,10 +54,10 @@ echo "deb http://apt.postgresql.org/pub/repos/apt/ $(lsb_release -cs)-pgdg main"
|
||||
echo "Updating package list..."
|
||||
$SUDO apt-get update -qq -y
|
||||
|
||||
# Install client tools for each version
|
||||
versions="12 13 14 15 16 17 18"
|
||||
# Install PostgreSQL client tools for each version
|
||||
pg_versions="12 13 14 15 16 17 18"
|
||||
|
||||
for version in $versions; do
|
||||
for version in $pg_versions; do
|
||||
echo "Installing PostgreSQL $version client tools..."
|
||||
|
||||
# Install client tools only
|
||||
@@ -85,22 +93,184 @@ for version in $versions; do
|
||||
echo
|
||||
done
|
||||
|
||||
echo "Installation completed!"
|
||||
echo "PostgreSQL client tools are available in: $POSTGRES_DIR"
|
||||
# ========== MySQL Installation ==========
|
||||
echo "========================================"
|
||||
echo "Installing MySQL client tools (versions 5.7, 8.0, 8.4)..."
|
||||
echo "========================================"
|
||||
|
||||
# Download and extract MySQL client tools
|
||||
mysql_versions="5.7 8.0 8.4"
|
||||
|
||||
for version in $mysql_versions; do
|
||||
echo "Installing MySQL $version client tools..."
|
||||
|
||||
version_dir="$MYSQL_DIR/mysql-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
|
||||
# Download MySQL client tools from official CDN
|
||||
# Note: 5.7 is in Downloads, 8.0 and 8.4 specific versions are in archives
|
||||
case $version in
|
||||
"5.7")
|
||||
MYSQL_URL="https://cdn.mysql.com/Downloads/MySQL-5.7/mysql-5.7.44-linux-glibc2.12-x86_64.tar.gz"
|
||||
;;
|
||||
"8.0")
|
||||
MYSQL_URL="https://cdn.mysql.com/archives/mysql-8.0/mysql-8.0.40-linux-glibc2.17-x86_64-minimal.tar.xz"
|
||||
;;
|
||||
"8.4")
|
||||
MYSQL_URL="https://cdn.mysql.com/archives/mysql-8.4/mysql-8.4.3-linux-glibc2.17-x86_64-minimal.tar.xz"
|
||||
;;
|
||||
esac
|
||||
|
||||
TEMP_DIR="/tmp/mysql_install_$version"
|
||||
mkdir -p "$TEMP_DIR"
|
||||
cd "$TEMP_DIR"
|
||||
|
||||
echo " Downloading MySQL $version..."
|
||||
wget -q "$MYSQL_URL" -O "mysql-$version.tar.gz" || wget -q "$MYSQL_URL" -O "mysql-$version.tar.xz"
|
||||
|
||||
echo " Extracting MySQL $version..."
|
||||
if [[ "$MYSQL_URL" == *.xz ]]; then
|
||||
tar -xJf "mysql-$version.tar.xz" 2>/dev/null || tar -xJf "mysql-$version.tar.gz" 2>/dev/null
|
||||
else
|
||||
tar -xzf "mysql-$version.tar.gz" 2>/dev/null || tar -xzf "mysql-$version.tar.xz" 2>/dev/null
|
||||
fi
|
||||
|
||||
# Find extracted directory
|
||||
EXTRACTED_DIR=$(ls -d mysql-*/ 2>/dev/null | head -1)
|
||||
|
||||
if [ -d "$EXTRACTED_DIR" ] && [ -f "$EXTRACTED_DIR/bin/mysqldump" ]; then
|
||||
# Copy client binaries
|
||||
cp "$EXTRACTED_DIR/bin/mysql" "$version_dir/bin/" 2>/dev/null || true
|
||||
cp "$EXTRACTED_DIR/bin/mysqldump" "$version_dir/bin/" 2>/dev/null || true
|
||||
chmod +x "$version_dir/bin/"*
|
||||
|
||||
echo " MySQL $version client tools installed successfully"
|
||||
else
|
||||
echo " Warning: Could not extract MySQL $version binaries"
|
||||
echo " You may need to install MySQL $version client tools manually"
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
cd - >/dev/null
|
||||
rm -rf "$TEMP_DIR"
|
||||
echo
|
||||
done
|
||||
|
||||
# ========== MariaDB Installation ==========
|
||||
echo "========================================"
|
||||
echo "Installing MariaDB client tools (versions 10.6 and 12.1)..."
|
||||
echo "========================================"
|
||||
|
||||
# MariaDB uses two client versions:
|
||||
# - 10.6 (legacy): For older servers (5.5, 10.1) that don't have generation_expression column
|
||||
# - 12.1 (modern): For newer servers (10.2+)
|
||||
|
||||
MARIADB_DIR="$(pwd)/mariadb"
|
||||
|
||||
echo "Installing MariaDB client tools to: $MARIADB_DIR"
|
||||
|
||||
# Install dependencies
|
||||
$SUDO apt-get install -y -qq apt-transport-https curl
|
||||
|
||||
# MariaDB versions to install with their URLs
|
||||
declare -A MARIADB_URLS=(
|
||||
["10.6"]="https://archive.mariadb.org/mariadb-10.6.21/bintar-linux-systemd-x86_64/mariadb-10.6.21-linux-systemd-x86_64.tar.gz"
|
||||
["12.1"]="https://archive.mariadb.org/mariadb-12.1.2/bintar-linux-systemd-x86_64/mariadb-12.1.2-linux-systemd-x86_64.tar.gz"
|
||||
)
|
||||
|
||||
mariadb_versions="10.6 12.1"
|
||||
|
||||
for version in $mariadb_versions; do
|
||||
echo "Installing MariaDB $version client tools..."
|
||||
|
||||
version_dir="$MARIADB_DIR/mariadb-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
|
||||
# Skip if already exists
|
||||
if [ -f "$version_dir/bin/mariadb-dump" ]; then
|
||||
echo "MariaDB $version already installed, skipping..."
|
||||
continue
|
||||
fi
|
||||
|
||||
url=${MARIADB_URLS[$version]}
|
||||
|
||||
TEMP_DIR="/tmp/mariadb_install_$version"
|
||||
mkdir -p "$TEMP_DIR"
|
||||
cd "$TEMP_DIR"
|
||||
|
||||
echo " Downloading MariaDB $version from official archive..."
|
||||
wget -q "$url" -O "mariadb-$version.tar.gz" || {
|
||||
echo " Warning: Could not download MariaDB $version binaries"
|
||||
cd - >/dev/null
|
||||
rm -rf "$TEMP_DIR"
|
||||
continue
|
||||
}
|
||||
|
||||
echo " Extracting MariaDB $version..."
|
||||
tar -xzf "mariadb-$version.tar.gz"
|
||||
EXTRACTED_DIR=$(ls -d mariadb-*/ 2>/dev/null | head -1)
|
||||
|
||||
if [ -d "$EXTRACTED_DIR" ] && [ -f "$EXTRACTED_DIR/bin/mariadb-dump" ]; then
|
||||
cp "$EXTRACTED_DIR/bin/mariadb" "$version_dir/bin/" 2>/dev/null || true
|
||||
cp "$EXTRACTED_DIR/bin/mariadb-dump" "$version_dir/bin/" 2>/dev/null || true
|
||||
chmod +x "$version_dir/bin/"*
|
||||
echo " MariaDB $version client tools installed successfully"
|
||||
else
|
||||
echo " Warning: Could not extract MariaDB $version binaries"
|
||||
fi
|
||||
|
||||
# Cleanup
|
||||
cd - >/dev/null
|
||||
rm -rf "$TEMP_DIR"
|
||||
echo
|
||||
done
|
||||
|
||||
echo
|
||||
|
||||
# List installed versions
|
||||
echo "========================================"
|
||||
echo "Installation completed!"
|
||||
echo "========================================"
|
||||
echo
|
||||
echo "PostgreSQL client tools are available in: $POSTGRES_DIR"
|
||||
echo "MySQL client tools are available in: $MYSQL_DIR"
|
||||
echo "MariaDB client tools are available in: $MARIADB_DIR"
|
||||
echo
|
||||
|
||||
# List installed PostgreSQL versions
|
||||
echo "Installed PostgreSQL client versions:"
|
||||
for version in $versions; do
|
||||
for version in $pg_versions; do
|
||||
version_dir="$POSTGRES_DIR/postgresql-$version"
|
||||
if [ -f "$version_dir/bin/pg_dump" ]; then
|
||||
echo " postgresql-$version: $version_dir/bin/"
|
||||
# Verify the correct version
|
||||
version_output=$("$version_dir/bin/pg_dump" --version 2>/dev/null | grep -o "pg_dump (PostgreSQL) [0-9]\+\.[0-9]\+")
|
||||
echo " Version check: $version_output"
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Usage example:"
|
||||
echo " $POSTGRES_DIR/postgresql-15/bin/pg_dump --version"
|
||||
echo "Installed MySQL client versions:"
|
||||
for version in $mysql_versions; do
|
||||
version_dir="$MYSQL_DIR/mysql-$version"
|
||||
if [ -f "$version_dir/bin/mysqldump" ]; then
|
||||
echo " mysql-$version: $version_dir/bin/"
|
||||
version_output=$("$version_dir/bin/mysqldump" --version 2>/dev/null | head -1)
|
||||
echo " Version check: $version_output"
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Installed MariaDB client versions:"
|
||||
for version in $mariadb_versions; do
|
||||
version_dir="$MARIADB_DIR/mariadb-$version"
|
||||
if [ -f "$version_dir/bin/mariadb-dump" ]; then
|
||||
echo " mariadb-$version: $version_dir/bin/"
|
||||
version_output=$("$version_dir/bin/mariadb-dump" --version 2>/dev/null | head -1)
|
||||
echo " Version check: $version_output"
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Usage examples:"
|
||||
echo " $POSTGRES_DIR/postgresql-15/bin/pg_dump --version"
|
||||
echo " $MYSQL_DIR/mysql-8.0/bin/mysqldump --version"
|
||||
echo " $MARIADB_DIR/mariadb-12.1/bin/mariadb-dump --version"
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
set -e # Exit on any error
|
||||
|
||||
echo "Installing PostgreSQL client tools versions 12-18 for MacOS..."
|
||||
echo "Installing PostgreSQL and MySQL client tools for MacOS..."
|
||||
echo
|
||||
|
||||
# Check if Homebrew is installed
|
||||
@@ -12,13 +12,16 @@ if ! command -v brew &> /dev/null; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Create postgresql directory
|
||||
# Create directories
|
||||
mkdir -p postgresql
|
||||
mkdir -p mysql
|
||||
|
||||
# Get absolute path
|
||||
# Get absolute paths
|
||||
POSTGRES_DIR="$(pwd)/postgresql"
|
||||
MYSQL_DIR="$(pwd)/mysql"
|
||||
|
||||
echo "Installing PostgreSQL client tools to: $POSTGRES_DIR"
|
||||
echo "Installing MySQL client tools to: $MYSQL_DIR"
|
||||
echo
|
||||
|
||||
# Update Homebrew
|
||||
@@ -27,7 +30,12 @@ brew update
|
||||
|
||||
# Install build dependencies
|
||||
echo "Installing build dependencies..."
|
||||
brew install wget openssl readline zlib
|
||||
brew install wget openssl readline zlib cmake
|
||||
|
||||
# ========== PostgreSQL Installation ==========
|
||||
echo "========================================"
|
||||
echo "Building PostgreSQL client tools (versions 12-18)..."
|
||||
echo "========================================"
|
||||
|
||||
# PostgreSQL source URLs
|
||||
declare -A PG_URLS=(
|
||||
@@ -41,7 +49,7 @@ declare -A PG_URLS=(
|
||||
)
|
||||
|
||||
# Create temporary build directory
|
||||
BUILD_DIR="/tmp/postgresql_build_$$"
|
||||
BUILD_DIR="/tmp/db_tools_build_$$"
|
||||
mkdir -p "$BUILD_DIR"
|
||||
|
||||
echo "Using temporary build directory: $BUILD_DIR"
|
||||
@@ -107,10 +115,10 @@ build_postgresql_client() {
|
||||
echo
|
||||
}
|
||||
|
||||
# Build each version
|
||||
versions="12 13 14 15 16 17 18"
|
||||
# Build each PostgreSQL version
|
||||
pg_versions="12 13 14 15 16 17 18"
|
||||
|
||||
for version in $versions; do
|
||||
for version in $pg_versions; do
|
||||
url=${PG_URLS[$version]}
|
||||
if [ -n "$url" ]; then
|
||||
build_postgresql_client "$version" "$url"
|
||||
@@ -119,17 +127,181 @@ for version in $versions; do
|
||||
fi
|
||||
done
|
||||
|
||||
# ========== MySQL Installation ==========
|
||||
echo "========================================"
|
||||
echo "Installing MySQL client tools (versions 5.7, 8.0, 8.4)..."
|
||||
echo "========================================"
|
||||
|
||||
# Detect architecture
|
||||
ARCH=$(uname -m)
|
||||
if [ "$ARCH" = "arm64" ]; then
|
||||
MYSQL_ARCH="arm64"
|
||||
else
|
||||
MYSQL_ARCH="x86_64"
|
||||
fi
|
||||
|
||||
# MySQL download URLs for macOS (using CDN)
|
||||
# Note: 5.7 is in Downloads, 8.0 and 8.4 specific versions are in archives
|
||||
declare -A MYSQL_URLS=(
|
||||
["5.7"]="https://cdn.mysql.com/Downloads/MySQL-5.7/mysql-5.7.44-macos10.14-x86_64.tar.gz"
|
||||
["8.0"]="https://cdn.mysql.com/archives/mysql-8.0/mysql-8.0.40-macos14-${MYSQL_ARCH}.tar.gz"
|
||||
["8.4"]="https://cdn.mysql.com/archives/mysql-8.4/mysql-8.4.3-macos14-${MYSQL_ARCH}.tar.gz"
|
||||
)
|
||||
|
||||
# Function to install MySQL client tools
|
||||
install_mysql_client() {
|
||||
local version=$1
|
||||
local url=$2
|
||||
local version_dir="$MYSQL_DIR/mysql-$version"
|
||||
|
||||
echo "Installing MySQL $version client tools..."
|
||||
|
||||
# Skip if already exists
|
||||
if [ -f "$version_dir/bin/mysqldump" ]; then
|
||||
echo "MySQL $version already installed, skipping..."
|
||||
return
|
||||
fi
|
||||
|
||||
mkdir -p "$version_dir/bin"
|
||||
cd "$BUILD_DIR"
|
||||
|
||||
# Download
|
||||
echo " Downloading MySQL $version..."
|
||||
wget -q "$url" -O "mysql-$version.tar.gz" || {
|
||||
echo " Warning: Could not download MySQL $version for $MYSQL_ARCH"
|
||||
echo " You may need to install MySQL $version client tools manually"
|
||||
return
|
||||
}
|
||||
|
||||
# Extract
|
||||
echo " Extracting MySQL $version..."
|
||||
tar -xzf "mysql-$version.tar.gz"
|
||||
|
||||
# Find extracted directory
|
||||
EXTRACTED_DIR=$(ls -d mysql-*/ 2>/dev/null | head -1)
|
||||
|
||||
if [ -d "$EXTRACTED_DIR" ] && [ -f "$EXTRACTED_DIR/bin/mysqldump" ]; then
|
||||
# Copy client binaries
|
||||
cp "$EXTRACTED_DIR/bin/mysql" "$version_dir/bin/" 2>/dev/null || true
|
||||
cp "$EXTRACTED_DIR/bin/mysqldump" "$version_dir/bin/" 2>/dev/null || true
|
||||
chmod +x "$version_dir/bin/"*
|
||||
|
||||
echo " MySQL $version client tools installed successfully"
|
||||
|
||||
# Test the installation
|
||||
local mysql_version=$("$version_dir/bin/mysqldump" --version 2>/dev/null | head -1)
|
||||
echo " Verified: $mysql_version"
|
||||
else
|
||||
echo " Warning: Could not extract MySQL $version binaries"
|
||||
echo " You may need to install MySQL $version client tools manually"
|
||||
fi
|
||||
|
||||
# Clean up
|
||||
rm -rf "mysql-$version.tar.gz" mysql-*/
|
||||
|
||||
echo
|
||||
}
|
||||
|
||||
# Install each MySQL version
|
||||
mysql_versions="5.7 8.0 8.4"
|
||||
|
||||
for version in $mysql_versions; do
|
||||
url=${MYSQL_URLS[$version]}
|
||||
if [ -n "$url" ]; then
|
||||
install_mysql_client "$version" "$url"
|
||||
else
|
||||
echo "Warning: No URL defined for MySQL $version"
|
||||
fi
|
||||
done
|
||||
|
||||
# ========== MariaDB Installation ==========
|
||||
echo "========================================"
|
||||
echo "Installing MariaDB client tools (versions 10.6 and 12.1)..."
|
||||
echo "========================================"
|
||||
|
||||
# MariaDB uses two client versions:
|
||||
# - 10.6 (legacy): For older servers (5.5, 10.1) that don't have generation_expression column
|
||||
# - 12.1 (modern): For newer servers (10.2+)
|
||||
|
||||
MARIADB_DIR="$(pwd)/mariadb"
|
||||
|
||||
echo "Installing MariaDB client tools to: $MARIADB_DIR"
|
||||
|
||||
# MariaDB versions to install
|
||||
# Note: MariaDB doesn't provide pre-built macOS binaries for older versions
|
||||
# We install via Homebrew and use the same version for both (Homebrew only has latest)
|
||||
# For production macOS use, the latest client should work with older servers for basic operations
|
||||
|
||||
mariadb_versions="10.6 12.1"
|
||||
|
||||
# Install MariaDB via Homebrew first (we'll use it for the modern version)
|
||||
echo " Installing MariaDB via Homebrew..."
|
||||
brew install mariadb 2>/dev/null || {
|
||||
echo " Warning: Could not install mariadb via Homebrew"
|
||||
brew install mariadb-connector-c 2>/dev/null || true
|
||||
}
|
||||
|
||||
# Find Homebrew MariaDB path
|
||||
BREW_MARIADB=""
|
||||
if [ -f "/opt/homebrew/bin/mariadb-dump" ]; then
|
||||
BREW_MARIADB="/opt/homebrew/bin"
|
||||
elif [ -f "/usr/local/bin/mariadb-dump" ]; then
|
||||
BREW_MARIADB="/usr/local/bin"
|
||||
else
|
||||
BREW_PREFIX=$(brew --prefix mariadb 2>/dev/null || echo "")
|
||||
if [ -n "$BREW_PREFIX" ] && [ -f "$BREW_PREFIX/bin/mariadb-dump" ]; then
|
||||
BREW_MARIADB="$BREW_PREFIX/bin"
|
||||
fi
|
||||
fi
|
||||
|
||||
for version in $mariadb_versions; do
|
||||
echo "Setting up MariaDB $version client tools..."
|
||||
|
||||
version_dir="$MARIADB_DIR/mariadb-$version"
|
||||
mkdir -p "$version_dir/bin"
|
||||
|
||||
# Skip if already exists
|
||||
if [ -f "$version_dir/bin/mariadb-dump" ]; then
|
||||
echo " MariaDB $version already installed, skipping..."
|
||||
continue
|
||||
fi
|
||||
|
||||
if [ -n "$BREW_MARIADB" ]; then
|
||||
# Link from Homebrew
|
||||
# Note: On macOS, we use the same Homebrew version for both paths
|
||||
# The Homebrew version (latest) should handle both old and new servers
|
||||
ln -sf "$BREW_MARIADB/mariadb" "$version_dir/bin/mariadb"
|
||||
ln -sf "$BREW_MARIADB/mariadb-dump" "$version_dir/bin/mariadb-dump"
|
||||
echo " MariaDB $version client tools linked from Homebrew"
|
||||
|
||||
# Test the installation
|
||||
mariadb_ver=$("$version_dir/bin/mariadb-dump" --version 2>/dev/null | head -1)
|
||||
echo " Verified: $mariadb_ver"
|
||||
else
|
||||
echo " Warning: Could not find MariaDB binaries for $version"
|
||||
echo " Please install MariaDB manually: brew install mariadb"
|
||||
fi
|
||||
echo
|
||||
done
|
||||
|
||||
echo
|
||||
|
||||
# Clean up build directory
|
||||
echo "Cleaning up build directory..."
|
||||
rm -rf "$BUILD_DIR"
|
||||
|
||||
echo "========================================"
|
||||
echo "Installation completed!"
|
||||
echo "========================================"
|
||||
echo
|
||||
echo "PostgreSQL client tools are available in: $POSTGRES_DIR"
|
||||
echo "MySQL client tools are available in: $MYSQL_DIR"
|
||||
echo "MariaDB client tools are available in: $MARIADB_DIR"
|
||||
echo
|
||||
|
||||
# List installed versions
|
||||
# List installed PostgreSQL versions
|
||||
echo "Installed PostgreSQL client versions:"
|
||||
for version in $versions; do
|
||||
for version in $pg_versions; do
|
||||
version_dir="$POSTGRES_DIR/postgresql-$version"
|
||||
if [ -f "$version_dir/bin/pg_dump" ]; then
|
||||
pg_version=$("$version_dir/bin/pg_dump" --version | cut -d' ' -f3)
|
||||
@@ -138,8 +310,34 @@ for version in $versions; do
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Usage example:"
|
||||
echo " $POSTGRES_DIR/postgresql-15/bin/pg_dump --version"
|
||||
echo "Installed MySQL client versions:"
|
||||
for version in $mysql_versions; do
|
||||
version_dir="$MYSQL_DIR/mysql-$version"
|
||||
if [ -f "$version_dir/bin/mysqldump" ]; then
|
||||
mysql_version=$("$version_dir/bin/mysqldump" --version 2>/dev/null | head -1)
|
||||
echo " mysql-$version: $version_dir/bin/"
|
||||
echo " $mysql_version"
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "To add a specific version to your PATH temporarily:"
|
||||
echo " export PATH=\"$POSTGRES_DIR/postgresql-15/bin:\$PATH\""
|
||||
echo "Installed MariaDB client versions:"
|
||||
for version in $mariadb_versions; do
|
||||
version_dir="$MARIADB_DIR/mariadb-$version"
|
||||
if [ -f "$version_dir/bin/mariadb-dump" ]; then
|
||||
mariadb_ver=$("$version_dir/bin/mariadb-dump" --version 2>/dev/null | head -1)
|
||||
echo " mariadb-$version: $version_dir/bin/"
|
||||
echo " $mariadb_ver"
|
||||
fi
|
||||
done
|
||||
|
||||
echo
|
||||
echo "Usage examples:"
|
||||
echo " $POSTGRES_DIR/postgresql-15/bin/pg_dump --version"
|
||||
echo " $MYSQL_DIR/mysql-8.0/bin/mysqldump --version"
|
||||
echo " $MARIADB_DIR/mariadb-12.1/bin/mariadb-dump --version"
|
||||
echo
|
||||
echo "To add specific versions to your PATH temporarily:"
|
||||
echo " export PATH=\"$POSTGRES_DIR/postgresql-15/bin:\$PATH\""
|
||||
echo " export PATH=\"$MYSQL_DIR/mysql-8.0/bin:\$PATH\""
|
||||
echo " export PATH=\"$MARIADB_DIR/mariadb-12.1/bin:\$PATH\""
|
||||
@@ -1,22 +1,37 @@
|
||||
@echo off
|
||||
setlocal enabledelayedexpansion
|
||||
|
||||
echo Downloading and installing PostgreSQL versions 12-18 for Windows...
|
||||
echo Downloading and installing PostgreSQL and MySQL client tools for Windows...
|
||||
echo.
|
||||
|
||||
:: Create downloads and postgresql directories if they don't exist
|
||||
:: Create directories if they don't exist
|
||||
if not exist "downloads" mkdir downloads
|
||||
if not exist "postgresql" mkdir postgresql
|
||||
if not exist "mysql" mkdir mysql
|
||||
if not exist "mariadb" mkdir mariadb
|
||||
|
||||
:: Get the absolute path to the postgresql directory
|
||||
:: Get the absolute paths
|
||||
set "POSTGRES_DIR=%cd%\postgresql"
|
||||
set "MYSQL_DIR=%cd%\mysql"
|
||||
set "MARIADB_DIR=%cd%\mariadb"
|
||||
|
||||
echo PostgreSQL will be installed to: %POSTGRES_DIR%
|
||||
echo MySQL will be installed to: %MYSQL_DIR%
|
||||
echo MariaDB will be installed to: %MARIADB_DIR%
|
||||
echo.
|
||||
|
||||
cd downloads
|
||||
|
||||
:: ========== PostgreSQL Installation ==========
|
||||
echo ========================================
|
||||
echo Installing PostgreSQL client tools (versions 12-18)...
|
||||
echo ========================================
|
||||
echo.
|
||||
|
||||
:: PostgreSQL download URLs for Windows x64
|
||||
set "BASE_URL=https://get.enterprisedb.com/postgresql"
|
||||
|
||||
:: Define versions and their corresponding download URLs
|
||||
:: Define PostgreSQL versions and their corresponding download URLs
|
||||
set "PG12_URL=%BASE_URL%/postgresql-12.20-1-windows-x64.exe"
|
||||
set "PG13_URL=%BASE_URL%/postgresql-13.16-1-windows-x64.exe"
|
||||
set "PG14_URL=%BASE_URL%/postgresql-14.13-1-windows-x64.exe"
|
||||
@@ -25,11 +40,11 @@ set "PG16_URL=%BASE_URL%/postgresql-16.4-1-windows-x64.exe"
|
||||
set "PG17_URL=%BASE_URL%/postgresql-17.0-1-windows-x64.exe"
|
||||
set "PG18_URL=%BASE_URL%/postgresql-18.0-1-windows-x64.exe"
|
||||
|
||||
:: Array of versions
|
||||
set "versions=12 13 14 15 16 17 18"
|
||||
:: PostgreSQL versions
|
||||
set "pg_versions=12 13 14 15 16 17 18"
|
||||
|
||||
:: Download and install each version
|
||||
for %%v in (%versions%) do (
|
||||
:: Download and install each PostgreSQL version
|
||||
for %%v in (%pg_versions%) do (
|
||||
echo Processing PostgreSQL %%v...
|
||||
set "filename=postgresql-%%v-windows-x64.exe"
|
||||
set "install_dir=%POSTGRES_DIR%\postgresql-%%v"
|
||||
@@ -45,7 +60,7 @@ for %%v in (%versions%) do (
|
||||
|
||||
if !errorlevel! neq 0 (
|
||||
echo Failed to download PostgreSQL %%v
|
||||
goto :next_version
|
||||
goto :next_pg_version
|
||||
)
|
||||
echo PostgreSQL %%v downloaded successfully
|
||||
) else (
|
||||
@@ -83,13 +98,238 @@ for %%v in (%versions%) do (
|
||||
)
|
||||
)
|
||||
|
||||
:next_version
|
||||
:next_pg_version
|
||||
echo.
|
||||
)
|
||||
|
||||
:: ========== MySQL Installation ==========
|
||||
echo ========================================
|
||||
echo Installing MySQL client tools (versions 5.7, 8.0, 8.4)...
|
||||
echo ========================================
|
||||
echo.
|
||||
|
||||
:: MySQL download URLs for Windows x64 (ZIP archives) - using CDN
|
||||
:: Note: 5.7 is in Downloads, 8.0 and 8.4 specific versions are in archives
|
||||
set "MYSQL57_URL=https://cdn.mysql.com/Downloads/MySQL-5.7/mysql-5.7.44-winx64.zip"
|
||||
set "MYSQL80_URL=https://cdn.mysql.com/archives/mysql-8.0/mysql-8.0.40-winx64.zip"
|
||||
set "MYSQL84_URL=https://cdn.mysql.com/archives/mysql-8.4/mysql-8.4.3-winx64.zip"
|
||||
|
||||
:: MySQL versions
|
||||
set "mysql_versions=5.7 8.0 8.4"
|
||||
|
||||
:: Download and install each MySQL version
|
||||
for %%v in (%mysql_versions%) do (
|
||||
echo Processing MySQL %%v...
|
||||
set "version_underscore=%%v"
|
||||
set "version_underscore=!version_underscore:.=!"
|
||||
set "filename=mysql-%%v-winx64.zip"
|
||||
set "install_dir=%MYSQL_DIR%\mysql-%%v"
|
||||
|
||||
:: Build the URL variable name and get its value
|
||||
call set "current_url=%%MYSQL!version_underscore!_URL%%"
|
||||
|
||||
:: Check if already installed
|
||||
if exist "!install_dir!\bin\mysqldump.exe" (
|
||||
echo MySQL %%v already installed, skipping...
|
||||
) else (
|
||||
:: Download if not exists
|
||||
if not exist "!filename!" (
|
||||
echo Downloading MySQL %%v...
|
||||
echo Downloading from: !current_url!
|
||||
curl -L -o "!filename!" -A "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" "!current_url!"
|
||||
if !errorlevel! neq 0 (
|
||||
echo ERROR: Download request failed
|
||||
goto :next_mysql_version
|
||||
)
|
||||
if not exist "!filename!" (
|
||||
echo ERROR: Download failed - file not created
|
||||
goto :next_mysql_version
|
||||
)
|
||||
for %%s in ("!filename!") do if %%~zs LSS 1000000 (
|
||||
echo ERROR: Download failed - file too small, likely error page
|
||||
del "!filename!" 2>nul
|
||||
goto :next_mysql_version
|
||||
)
|
||||
echo MySQL %%v downloaded successfully
|
||||
) else (
|
||||
echo MySQL %%v already downloaded
|
||||
)
|
||||
|
||||
:: Verify file exists before extraction
|
||||
if not exist "!filename!" (
|
||||
echo Download file not found, skipping extraction...
|
||||
goto :next_mysql_version
|
||||
)
|
||||
|
||||
:: Extract MySQL
|
||||
echo Extracting MySQL %%v...
|
||||
mkdir "!install_dir!" 2>nul
|
||||
|
||||
powershell -Command "Expand-Archive -Path '!filename!' -DestinationPath '!install_dir!_temp' -Force"
|
||||
|
||||
:: Move files from nested directory to install_dir
|
||||
for /d %%d in ("!install_dir!_temp\mysql-*") do (
|
||||
if exist "%%d\bin\mysqldump.exe" (
|
||||
mkdir "!install_dir!\bin" 2>nul
|
||||
copy "%%d\bin\mysql.exe" "!install_dir!\bin\" >nul 2>&1
|
||||
copy "%%d\bin\mysqldump.exe" "!install_dir!\bin\" >nul 2>&1
|
||||
)
|
||||
)
|
||||
|
||||
:: Cleanup temp directory
|
||||
rmdir /s /q "!install_dir!_temp" 2>nul
|
||||
|
||||
:: Verify installation
|
||||
if exist "!install_dir!\bin\mysqldump.exe" (
|
||||
echo MySQL %%v client tools installed successfully
|
||||
) else (
|
||||
echo Failed to install MySQL %%v - mysqldump.exe not found
|
||||
)
|
||||
)
|
||||
|
||||
:next_mysql_version
|
||||
echo.
|
||||
)
|
||||
|
||||
:: ========== MariaDB Installation ==========
|
||||
echo ========================================
|
||||
echo Installing MariaDB client tools (versions 10.6 and 12.1)...
|
||||
echo ========================================
|
||||
echo.
|
||||
|
||||
:: MariaDB uses two client versions:
|
||||
:: - 10.6 (legacy): For older servers (5.5, 10.1) that don't have generation_expression column
|
||||
:: - 12.1 (modern): For newer servers (10.2+)
|
||||
|
||||
:: MariaDB download URLs
|
||||
set "MARIADB106_URL=https://archive.mariadb.org/mariadb-10.6.21/winx64-packages/mariadb-10.6.21-winx64.zip"
|
||||
set "MARIADB121_URL=https://archive.mariadb.org/mariadb-12.1.2/winx64-packages/mariadb-12.1.2-winx64.zip"
|
||||
|
||||
:: MariaDB versions to install
|
||||
set "mariadb_versions=10.6 12.1"
|
||||
|
||||
:: Download and install each MariaDB version
|
||||
for %%v in (%mariadb_versions%) do (
|
||||
echo Processing MariaDB %%v...
|
||||
set "version_underscore=%%v"
|
||||
set "version_underscore=!version_underscore:.=!"
|
||||
set "mariadb_install_dir=%MARIADB_DIR%\mariadb-%%v"
|
||||
|
||||
:: Build the URL variable name and get its value
|
||||
call set "current_url=%%MARIADB!version_underscore!_URL%%"
|
||||
|
||||
:: Check if already installed
|
||||
if exist "!mariadb_install_dir!\bin\mariadb-dump.exe" (
|
||||
echo MariaDB %%v already installed, skipping...
|
||||
) else (
|
||||
:: Extract version number from URL for filename
|
||||
for %%u in ("!current_url!") do set "mariadb_filename=%%~nxu"
|
||||
|
||||
if not exist "!mariadb_filename!" (
|
||||
echo Downloading MariaDB %%v...
|
||||
echo Downloading from: !current_url!
|
||||
curl -L -o "!mariadb_filename!" -A "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36" "!current_url!"
|
||||
if !errorlevel! neq 0 (
|
||||
echo ERROR: Download request failed
|
||||
goto :next_mariadb_version
|
||||
)
|
||||
if not exist "!mariadb_filename!" (
|
||||
echo ERROR: Download failed - file not created
|
||||
goto :next_mariadb_version
|
||||
)
|
||||
for %%s in ("!mariadb_filename!") do if %%~zs LSS 1000000 (
|
||||
echo ERROR: Download failed - file too small, likely error page
|
||||
del "!mariadb_filename!" 2>nul
|
||||
goto :next_mariadb_version
|
||||
)
|
||||
echo MariaDB %%v downloaded successfully
|
||||
) else (
|
||||
echo MariaDB %%v already downloaded
|
||||
)
|
||||
|
||||
:: Verify file exists before extraction
|
||||
if not exist "!mariadb_filename!" (
|
||||
echo Download file not found, skipping extraction...
|
||||
goto :next_mariadb_version
|
||||
)
|
||||
|
||||
:: Extract MariaDB
|
||||
echo Extracting MariaDB %%v...
|
||||
mkdir "!mariadb_install_dir!" 2>nul
|
||||
mkdir "!mariadb_install_dir!\bin" 2>nul
|
||||
|
||||
powershell -Command "Expand-Archive -Path '!mariadb_filename!' -DestinationPath '!mariadb_install_dir!_temp' -Force"
|
||||
|
||||
:: Move files from nested directory to install_dir
|
||||
for /d %%d in ("!mariadb_install_dir!_temp\mariadb-*") do (
|
||||
if exist "%%d\bin\mariadb-dump.exe" (
|
||||
copy "%%d\bin\mariadb.exe" "!mariadb_install_dir!\bin\" >nul 2>&1
|
||||
copy "%%d\bin\mariadb-dump.exe" "!mariadb_install_dir!\bin\" >nul 2>&1
|
||||
)
|
||||
)
|
||||
|
||||
:: Cleanup temp directory
|
||||
rmdir /s /q "!mariadb_install_dir!_temp" 2>nul
|
||||
|
||||
:: Verify installation
|
||||
if exist "!mariadb_install_dir!\bin\mariadb-dump.exe" (
|
||||
echo MariaDB %%v client tools installed successfully
|
||||
) else (
|
||||
echo Failed to install MariaDB %%v - mariadb-dump.exe not found
|
||||
)
|
||||
)
|
||||
|
||||
:next_mariadb_version
|
||||
echo.
|
||||
)
|
||||
|
||||
:skip_mariadb
|
||||
echo.
|
||||
|
||||
cd ..
|
||||
|
||||
echo.
|
||||
echo ========================================
|
||||
echo Installation process completed!
|
||||
echo ========================================
|
||||
echo.
|
||||
echo PostgreSQL versions are installed in: %POSTGRES_DIR%
|
||||
echo MySQL versions are installed in: %MYSQL_DIR%
|
||||
echo MariaDB is installed in: %MARIADB_DIR%
|
||||
echo.
|
||||
|
||||
:: List installed PostgreSQL versions
|
||||
echo Installed PostgreSQL client versions:
|
||||
for %%v in (%pg_versions%) do (
|
||||
set "version_dir=%POSTGRES_DIR%\postgresql-%%v"
|
||||
if exist "!version_dir!\bin\pg_dump.exe" (
|
||||
echo postgresql-%%v: !version_dir!\bin\
|
||||
)
|
||||
)
|
||||
|
||||
echo.
|
||||
echo Installed MySQL client versions:
|
||||
for %%v in (%mysql_versions%) do (
|
||||
set "version_dir=%MYSQL_DIR%\mysql-%%v"
|
||||
if exist "!version_dir!\bin\mysqldump.exe" (
|
||||
echo mysql-%%v: !version_dir!\bin\
|
||||
)
|
||||
)
|
||||
|
||||
echo.
|
||||
echo Installed MariaDB client versions:
|
||||
for %%v in (%mariadb_versions%) do (
|
||||
set "version_dir=%MARIADB_DIR%\mariadb-%%v"
|
||||
if exist "!version_dir!\bin\mariadb-dump.exe" (
|
||||
echo mariadb-%%v: !version_dir!\bin\
|
||||
)
|
||||
)
|
||||
|
||||
echo.
|
||||
echo Usage examples:
|
||||
echo %POSTGRES_DIR%\postgresql-15\bin\pg_dump.exe --version
|
||||
echo %MYSQL_DIR%\mysql-8.0\bin\mysqldump.exe --version
|
||||
echo %MARIADB_DIR%\mariadb-12.1\bin\mariadb-dump.exe --version
|
||||
echo.
|
||||
|
||||
pause
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
This directory is needed only for development and CI\CD.
|
||||
|
||||
We have to download and install all the PostgreSQL versions from 12 to 18 locally.
|
||||
This is needed so we can call pg_dump, pg_dumpall, etc. on each version of the PostgreSQL database.
|
||||
We have to download and install all the PostgreSQL versions from 12 to 18, MySQL versions 5.7, 8.0, 8.4 and MariaDB client tools locally.
|
||||
This is needed so we can call pg_dump, pg_restore, mysqldump, mysql, mariadb-dump, mariadb, etc. on each version of the database.
|
||||
|
||||
You do not need to install PostgreSQL fully with all the components.
|
||||
We only need the client tools (pg_dump, pg_dumpall, psql, etc.) for each version.
|
||||
You do not need to install the databases fully with all the components.
|
||||
We only need the client tools for each version.
|
||||
|
||||
We have to install the following:
|
||||
## Required Versions
|
||||
|
||||
### PostgreSQL
|
||||
|
||||
- PostgreSQL 12
|
||||
- PostgreSQL 13
|
||||
@@ -16,6 +18,21 @@ We have to install the following:
|
||||
- PostgreSQL 17
|
||||
- PostgreSQL 18
|
||||
|
||||
### MySQL
|
||||
|
||||
- MySQL 5.7
|
||||
- MySQL 8.0
|
||||
- MySQL 8.4
|
||||
|
||||
### MariaDB
|
||||
|
||||
MariaDB uses two client versions to support all server versions:
|
||||
|
||||
- MariaDB 10.6 (legacy client - for older servers 5.5 and 10.1)
|
||||
- MariaDB 12.1 (modern client - for servers 10.2+)
|
||||
|
||||
The reason for two versions is that MariaDB 12.1 client uses SQL queries that reference the `generation_expression` column in `information_schema.columns`, which was only added in MariaDB 10.2. Older servers (5.5, 10.1) don't have this column and fail with newer clients.
|
||||
|
||||
## Installation
|
||||
|
||||
Run the appropriate download script for your platform:
|
||||
@@ -45,12 +62,15 @@ chmod +x download_macos.sh
|
||||
### Windows
|
||||
|
||||
- Downloads official PostgreSQL installers from EnterpriseDB
|
||||
- Downloads official MySQL ZIP archives from dev.mysql.com
|
||||
- Installs client tools only (no server components)
|
||||
- May require administrator privileges during installation
|
||||
- May require administrator privileges during PostgreSQL installation
|
||||
|
||||
### Linux (Debian/Ubuntu)
|
||||
|
||||
- Uses the official PostgreSQL APT repository
|
||||
- Downloads MySQL client tools from official archives
|
||||
- Installs MariaDB client from official MariaDB repository
|
||||
- Requires sudo privileges to install packages
|
||||
- Creates symlinks in version-specific directories for consistency
|
||||
|
||||
@@ -58,17 +78,23 @@ chmod +x download_macos.sh
|
||||
|
||||
- Requires Homebrew to be installed
|
||||
- Compiles PostgreSQL from source (client tools only)
|
||||
- Takes longer than other platforms due to compilation
|
||||
- Downloads pre-built MySQL binaries from dev.mysql.com
|
||||
- Downloads pre-built MariaDB binaries or installs via Homebrew
|
||||
- Takes longer than other platforms due to PostgreSQL compilation
|
||||
- Supports both Intel (x86_64) and Apple Silicon (arm64)
|
||||
|
||||
## Manual Installation
|
||||
|
||||
If something goes wrong with the automated scripts, install manually.
|
||||
The final directory structure should match:
|
||||
|
||||
### PostgreSQL
|
||||
|
||||
```
|
||||
./tools/postgresql/postgresql-{version}/bin/pg_dump
|
||||
./tools/postgresql/postgresql-{version}/bin/pg_dumpall
|
||||
./tools/postgresql/postgresql-{version}/bin/psql
|
||||
./tools/postgresql/postgresql-{version}/bin/pg_restore
|
||||
```
|
||||
|
||||
For example:
|
||||
@@ -81,14 +107,112 @@ For example:
|
||||
- `./tools/postgresql/postgresql-17/bin/pg_dump`
|
||||
- `./tools/postgresql/postgresql-18/bin/pg_dump`
|
||||
|
||||
### MySQL
|
||||
|
||||
```
|
||||
./tools/mysql/mysql-{version}/bin/mysqldump
|
||||
./tools/mysql/mysql-{version}/bin/mysql
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
- `./tools/mysql/mysql-5.7/bin/mysqldump`
|
||||
- `./tools/mysql/mysql-8.0/bin/mysqldump`
|
||||
- `./tools/mysql/mysql-8.4/bin/mysqldump`
|
||||
|
||||
### MariaDB
|
||||
|
||||
MariaDB uses two client versions to handle compatibility with all server versions:
|
||||
|
||||
```
|
||||
./tools/mariadb/mariadb-{client-version}/bin/mariadb-dump
|
||||
./tools/mariadb/mariadb-{client-version}/bin/mariadb
|
||||
```
|
||||
|
||||
For example:
|
||||
|
||||
- `./tools/mariadb/mariadb-10.6/bin/mariadb-dump` (legacy - for servers 5.5, 10.1)
|
||||
- `./tools/mariadb/mariadb-12.1/bin/mariadb-dump` (modern - for servers 10.2+)
|
||||
|
||||
## Usage
|
||||
|
||||
After installation, you can use version-specific tools:
|
||||
|
||||
```bash
|
||||
# Windows
|
||||
# Windows - PostgreSQL
|
||||
./postgresql/postgresql-15/bin/pg_dump.exe --version
|
||||
|
||||
# Linux/MacOS
|
||||
# Windows - MySQL
|
||||
./mysql/mysql-8.0/bin/mysqldump.exe --version
|
||||
|
||||
# Windows - MariaDB
|
||||
./mariadb/mariadb-12.1/bin/mariadb-dump.exe --version
|
||||
|
||||
# Linux/MacOS - PostgreSQL
|
||||
./postgresql/postgresql-15/bin/pg_dump --version
|
||||
|
||||
# Linux/MacOS - MySQL
|
||||
./mysql/mysql-8.0/bin/mysqldump --version
|
||||
|
||||
# Linux/MacOS - MariaDB
|
||||
./mariadb/mariadb-12.1/bin/mariadb-dump --version
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
The application expects these environment variables to be set (or uses defaults):
|
||||
|
||||
```env
|
||||
# PostgreSQL tools directory (default: ./tools/postgresql)
|
||||
POSTGRES_INSTALL_DIR=C:\path\to\tools\postgresql
|
||||
|
||||
# MySQL tools directory (default: ./tools/mysql)
|
||||
MYSQL_INSTALL_DIR=C:\path\to\tools\mysql
|
||||
|
||||
# MariaDB tools directory (default: ./tools/mariadb)
|
||||
# Contains subdirectories: mariadb-10.6 and mariadb-12.1
|
||||
MARIADB_INSTALL_DIR=C:\path\to\tools\mariadb
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### MySQL 5.7 on Apple Silicon (M1/M2/M3)
|
||||
|
||||
MySQL 5.7 does not have native ARM64 binaries for macOS. The script will attempt to download the x86_64 version, which may work under Rosetta 2. If you encounter issues:
|
||||
|
||||
1. Ensure Rosetta 2 is installed: `softwareupdate --install-rosetta`
|
||||
2. Or skip MySQL 5.7 if you don't need to support that version
|
||||
|
||||
### Permission Errors on Linux
|
||||
|
||||
If you encounter permission errors, ensure you have sudo privileges:
|
||||
|
||||
```bash
|
||||
sudo ./download_linux.sh
|
||||
```
|
||||
|
||||
### Download Failures
|
||||
|
||||
If downloads fail, you can manually download the files:
|
||||
|
||||
- PostgreSQL: https://www.postgresql.org/ftp/source/
|
||||
- MySQL: https://dev.mysql.com/downloads/mysql/
|
||||
- MariaDB: https://mariadb.org/download/ or https://cdn.mysql.com/archives/mariadb-12.0/
|
||||
|
||||
### MariaDB Client Compatibility
|
||||
|
||||
MariaDB client tools require different versions depending on the server:
|
||||
|
||||
**Legacy client (10.6)** - Required for:
|
||||
|
||||
- MariaDB 5.5
|
||||
- MariaDB 10.1
|
||||
|
||||
**Modern client (12.1)** - Works with:
|
||||
|
||||
- MariaDB 10.2 - 10.6
|
||||
- MariaDB 10.11
|
||||
- MariaDB 11.4, 11.8
|
||||
- MariaDB 12.0
|
||||
|
||||
The reason is that MariaDB 12.1 client uses SQL queries referencing the `generation_expression` column in `information_schema.columns`, which was added in MariaDB 10.2. The application automatically selects the appropriate client version based on the target server version.
|
||||
|
||||
@@ -2,11 +2,21 @@ apiVersion: v2
|
||||
name: postgresus
|
||||
description: A Helm chart for Postgresus - PostgreSQL backup and management system
|
||||
type: application
|
||||
version: 1.0.0
|
||||
appVersion: "v1.45.3"
|
||||
version: 0.0.0
|
||||
appVersion: "latest"
|
||||
keywords:
|
||||
- postgresql
|
||||
- backup
|
||||
- database
|
||||
- restore
|
||||
home: https://github.com/RostislavDugin/postgresus
|
||||
|
||||
sources:
|
||||
- https://github.com/RostislavDugin/postgresus
|
||||
- https://github.com/RostislavDugin/postgresus/tree/main/deploy/helm
|
||||
|
||||
maintainers:
|
||||
- name: Rostislav Dugin
|
||||
url: https://github.com/RostislavDugin
|
||||
|
||||
icon: https://raw.githubusercontent.com/RostislavDugin/postgresus/main/frontend/public/logo.svg
|
||||
|
||||
@@ -2,17 +2,24 @@
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace
|
||||
```
|
||||
|
||||
After installation, get the external IP:
|
||||
Install directly from the OCI registry (no need to clone the repository):
|
||||
|
||||
```bash
|
||||
kubectl get svc -n postgresus
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace
|
||||
```
|
||||
|
||||
Access Postgresus at `http://<EXTERNAL-IP>` (port 80).
|
||||
The `-n postgresus --create-namespace` flags control which namespace the chart is installed into. You can use any namespace name you prefer.
|
||||
|
||||
## Accessing Postgresus
|
||||
|
||||
By default, the chart creates a ClusterIP service. Use port-forward to access:
|
||||
|
||||
```bash
|
||||
kubectl port-forward svc/postgresus-service 4005:4005 -n postgresus
|
||||
```
|
||||
|
||||
Then open `http://localhost:4005` in your browser.
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -20,21 +27,42 @@ Access Postgresus at `http://<EXTERNAL-IP>` (port 80).
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ------------------ | ------------------ | --------------------------- |
|
||||
| `namespace.create` | Create namespace | `true` |
|
||||
| `namespace.name` | Namespace name | `postgresus` |
|
||||
| `image.repository` | Docker image | `rostislavdugin/postgresus` |
|
||||
| `image.tag` | Image tag | `latest` |
|
||||
| `image.pullPolicy` | Image pull policy | `Always` |
|
||||
| `replicaCount` | Number of replicas | `1` |
|
||||
|
||||
### Resources
|
||||
### Custom Root CA
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| --------------------------- | -------------- | ------------- |
|
||||
| `resources.requests.memory` | Memory request | `1Gi` |
|
||||
| `resources.requests.cpu` | CPU request | `500m` |
|
||||
| `resources.limits.memory` | Memory limit | `1Gi` |
|
||||
| `resources.limits.cpu` | CPU limit | `500m` |
|
||||
| Parameter | Description | Default Value |
|
||||
| -------------- | ---------------------------------------- | ------------- |
|
||||
| `customRootCA` | Name of Secret containing CA certificate | `""` |
|
||||
|
||||
To trust a custom CA certificate (e.g., for internal services with self-signed certificates):
|
||||
|
||||
1. Create a Secret with your CA certificate:
|
||||
|
||||
```bash
|
||||
kubectl create secret generic my-root-ca \
|
||||
--from-file=ca.crt=./path/to/ca-certificate.crt
|
||||
```
|
||||
|
||||
2. Reference it in values:
|
||||
|
||||
```yaml
|
||||
customRootCA: my-root-ca
|
||||
```
|
||||
|
||||
The certificate will be mounted to `/etc/ssl/certs/custom-root-ca.crt` and the `SSL_CERT_FILE` environment variable will be set automatically.
|
||||
|
||||
### Service
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| -------------------------- | ----------------------- | ------------- |
|
||||
| `service.type` | Service type | `ClusterIP` |
|
||||
| `service.port` | Service port | `4005` |
|
||||
| `service.targetPort` | Container port | `4005` |
|
||||
| `service.headless.enabled` | Enable headless service | `true` |
|
||||
|
||||
### Storage
|
||||
|
||||
@@ -46,93 +74,80 @@ Access Postgresus at `http://<EXTERNAL-IP>` (port 80).
|
||||
| `persistence.size` | Storage size | `10Gi` |
|
||||
| `persistence.mountPath` | Mount path | `/postgresus-data` |
|
||||
|
||||
### Service
|
||||
### Resources
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| -------------------------- | ----------------------- | -------------- |
|
||||
| `service.type` | Service type | `LoadBalancer` |
|
||||
| `service.port` | External port | `80` |
|
||||
| `service.targetPort` | Container port | `4005` |
|
||||
| `service.headless.enabled` | Enable headless service | `true` |
|
||||
| Parameter | Description | Default Value |
|
||||
| --------------------------- | -------------- | ------------- |
|
||||
| `resources.requests.memory` | Memory request | `1Gi` |
|
||||
| `resources.requests.cpu` | CPU request | `500m` |
|
||||
| `resources.limits.memory` | Memory limit | `1Gi` |
|
||||
| `resources.limits.cpu` | CPU limit | `500m` |
|
||||
|
||||
### Traffic Exposure (3 Options)
|
||||
## External Access Options
|
||||
|
||||
The chart supports 3 ways to expose Postgresus:
|
||||
### Option 1: Port Forward (Default)
|
||||
|
||||
| Method | Use Case | Default |
|
||||
| ------ | -------- | ------- |
|
||||
| **LoadBalancer/NodePort** | Simple cloud clusters | Enabled |
|
||||
| **Ingress** | Traditional nginx/traefik ingress controllers | Disabled |
|
||||
| **HTTPRoute (Gateway API)** | Modern gateways (Istio, Envoy, Cilium) | Disabled |
|
||||
|
||||
#### Ingress
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ----------------------- | ----------------- | ------------------------ |
|
||||
| `ingress.enabled` | Enable Ingress | `false` |
|
||||
| `ingress.className` | Ingress class | `nginx` |
|
||||
| `ingress.hosts[0].host` | Hostname | `postgresus.example.com` |
|
||||
| `ingress.tls` | TLS configuration | `[]` |
|
||||
|
||||
#### HTTPRoute (Gateway API)
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| --------------------- | -------------------------- | ---------------------------------- |
|
||||
| `route.enabled` | Enable HTTPRoute | `false` |
|
||||
| `route.apiVersion` | Gateway API version | `gateway.networking.k8s.io/v1` |
|
||||
| `route.hostnames` | Hostnames for the route | `["postgresus.example.com"]` |
|
||||
| `route.parentRefs` | Gateway references | `[]` |
|
||||
| `route.annotations` | Route annotations | `{}` |
|
||||
|
||||
### Health Checks
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ------------------------ | ---------------------- | ------------- |
|
||||
| `livenessProbe.enabled` | Enable liveness probe | `true` |
|
||||
| `readinessProbe.enabled` | Enable readiness probe | `true` |
|
||||
|
||||
## Examples
|
||||
|
||||
### Basic Installation (LoadBalancer on port 80)
|
||||
|
||||
Default installation exposes Postgresus via LoadBalancer on port 80:
|
||||
Best for development or quick access:
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace
|
||||
kubectl port-forward svc/postgresus-service 4005:4005 -n postgresus
|
||||
```
|
||||
|
||||
Access via `http://<EXTERNAL-IP>`
|
||||
Access at `http://localhost:4005`
|
||||
|
||||
### Using NodePort
|
||||
### Option 2: NodePort
|
||||
|
||||
If your cluster doesn't support LoadBalancer:
|
||||
For direct access via node IP:
|
||||
|
||||
```yaml
|
||||
# nodeport-values.yaml
|
||||
service:
|
||||
type: NodePort
|
||||
port: 80
|
||||
port: 4005
|
||||
targetPort: 4005
|
||||
nodePort: 30080
|
||||
```
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace -f nodeport-values.yaml
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
-f nodeport-values.yaml
|
||||
```
|
||||
|
||||
Access via `http://<NODE-IP>:30080`
|
||||
Access at `http://<NODE-IP>:30080`
|
||||
|
||||
### Enable Ingress with HTTPS
|
||||
### Option 3: LoadBalancer
|
||||
|
||||
For cloud environments with load balancer support:
|
||||
|
||||
```yaml
|
||||
# loadbalancer-values.yaml
|
||||
service:
|
||||
type: LoadBalancer
|
||||
port: 80
|
||||
targetPort: 4005
|
||||
```
|
||||
|
||||
```bash
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
-f loadbalancer-values.yaml
|
||||
```
|
||||
|
||||
Get the external IP:
|
||||
|
||||
```bash
|
||||
kubectl get svc -n postgresus
|
||||
```
|
||||
|
||||
Access at `http://<EXTERNAL-IP>`
|
||||
|
||||
### Option 4: Ingress
|
||||
|
||||
For domain-based access with TLS:
|
||||
|
||||
```yaml
|
||||
# ingress-values.yaml
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 4005
|
||||
targetPort: 4005
|
||||
|
||||
ingress:
|
||||
enabled: true
|
||||
className: nginx
|
||||
@@ -151,18 +166,17 @@ ingress:
|
||||
```
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace -f ingress-values.yaml
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
-f ingress-values.yaml
|
||||
```
|
||||
|
||||
### HTTPRoute (Gateway API)
|
||||
### Option 5: HTTPRoute (Gateway API)
|
||||
|
||||
For clusters using Istio, Envoy Gateway, Cilium, or other Gateway API implementations:
|
||||
|
||||
```yaml
|
||||
# httproute-values.yaml
|
||||
service:
|
||||
type: ClusterIP
|
||||
|
||||
route:
|
||||
enabled: true
|
||||
hostnames:
|
||||
@@ -173,10 +187,37 @@ route:
|
||||
```
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace -f httproute-values.yaml
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
-f httproute-values.yaml
|
||||
```
|
||||
|
||||
### Custom Storage Size
|
||||
## Ingress Configuration
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ----------------------- | ----------------- | ------------------------ |
|
||||
| `ingress.enabled` | Enable Ingress | `false` |
|
||||
| `ingress.className` | Ingress class | `nginx` |
|
||||
| `ingress.hosts[0].host` | Hostname | `postgresus.example.com` |
|
||||
| `ingress.tls` | TLS configuration | `[]` |
|
||||
|
||||
## HTTPRoute Configuration
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ------------------ | ----------------------- | ------------------------------ |
|
||||
| `route.enabled` | Enable HTTPRoute | `false` |
|
||||
| `route.apiVersion` | Gateway API version | `gateway.networking.k8s.io/v1` |
|
||||
| `route.hostnames` | Hostnames for the route | `["postgresus.example.com"]` |
|
||||
| `route.parentRefs` | Gateway references | `[]` |
|
||||
|
||||
## Health Checks
|
||||
|
||||
| Parameter | Description | Default Value |
|
||||
| ------------------------ | ---------------------- | ------------- |
|
||||
| `livenessProbe.enabled` | Enable liveness probe | `true` |
|
||||
| `readinessProbe.enabled` | Enable readiness probe | `true` |
|
||||
|
||||
## Custom Storage Size
|
||||
|
||||
```yaml
|
||||
# storage-values.yaml
|
||||
@@ -186,5 +227,19 @@ persistence:
|
||||
```
|
||||
|
||||
```bash
|
||||
helm install postgresus ./deploy/helm -n postgresus --create-namespace -f storage-values.yaml
|
||||
helm install postgresus oci://ghcr.io/rostislavdugin/charts/postgresus \
|
||||
-n postgresus --create-namespace \
|
||||
-f storage-values.yaml
|
||||
```
|
||||
|
||||
## Upgrade
|
||||
|
||||
```bash
|
||||
helm upgrade postgresus oci://ghcr.io/rostislavdugin/charts/postgresus -n postgresus
|
||||
```
|
||||
|
||||
## Uninstall
|
||||
|
||||
```bash
|
||||
helm uninstall postgresus -n postgresus
|
||||
```
|
||||
|
||||
@@ -61,12 +61,8 @@ Create the name of the service account to use
|
||||
{{- end }}
|
||||
|
||||
{{/*
|
||||
Namespace
|
||||
Namespace - uses the release namespace from helm install -n <namespace>
|
||||
*/}}
|
||||
{{- define "postgresus.namespace" -}}
|
||||
{{- if .Values.namespace.create }}
|
||||
{{- .Values.namespace.name }}
|
||||
{{- else }}
|
||||
{{- .Release.Namespace }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
|
||||
@@ -1,8 +0,0 @@
|
||||
{{- if .Values.namespace.create }}
|
||||
apiVersion: v1
|
||||
kind: Namespace
|
||||
metadata:
|
||||
name: {{ .Values.namespace.name }}
|
||||
labels:
|
||||
{{- include "postgresus.labels" . | nindent 4 }}
|
||||
{{- end }}
|
||||
@@ -39,6 +39,11 @@ spec:
|
||||
- name: {{ .Chart.Name }}
|
||||
image: "{{ .Values.image.repository }}:{{ .Values.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: {{ .Values.image.pullPolicy }}
|
||||
{{- if .Values.customRootCA }}
|
||||
env:
|
||||
- name: SSL_CERT_FILE
|
||||
value: /etc/ssl/certs/custom-root-ca.crt
|
||||
{{- end }}
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.service.targetPort }}
|
||||
@@ -46,6 +51,12 @@ spec:
|
||||
volumeMounts:
|
||||
- name: postgresus-storage
|
||||
mountPath: {{ .Values.persistence.mountPath }}
|
||||
{{- if .Values.customRootCA }}
|
||||
- name: custom-root-ca
|
||||
mountPath: /etc/ssl/certs/custom-root-ca.crt
|
||||
subPath: ca.crt
|
||||
readOnly: true
|
||||
{{- end }}
|
||||
resources:
|
||||
{{- toYaml .Values.resources | nindent 12 }}
|
||||
{{- if .Values.livenessProbe.enabled }}
|
||||
@@ -66,6 +77,12 @@ spec:
|
||||
timeoutSeconds: {{ .Values.readinessProbe.timeoutSeconds }}
|
||||
failureThreshold: {{ .Values.readinessProbe.failureThreshold }}
|
||||
{{- end }}
|
||||
{{- if .Values.customRootCA }}
|
||||
volumes:
|
||||
- name: custom-root-ca
|
||||
secret:
|
||||
secretName: {{ .Values.customRootCA }}
|
||||
{{- end }}
|
||||
{{- if .Values.persistence.enabled }}
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user