mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
433 lines
11 KiB
Go
433 lines
11 KiB
Go
package mariadb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
"databasus-backend/internal/util/encryption"
|
|
"databasus-backend/internal/util/tools"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
type MariadbDatabase struct {
|
|
ID uuid.UUID `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
|
|
DatabaseID *uuid.UUID `json:"databaseId" gorm:"type:uuid;column:database_id"`
|
|
|
|
Version tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
|
|
|
|
Host string `json:"host" gorm:"type:text;not null"`
|
|
Port int `json:"port" gorm:"type:int;not null"`
|
|
Username string `json:"username" gorm:"type:text;not null"`
|
|
Password string `json:"password" gorm:"type:text;not null"`
|
|
Database *string `json:"database" gorm:"type:text"`
|
|
IsHttps bool `json:"isHttps" gorm:"type:boolean;default:false"`
|
|
}
|
|
|
|
func (m *MariadbDatabase) TableName() string {
|
|
return "mariadb_databases"
|
|
}
|
|
|
|
func (m *MariadbDatabase) Validate() error {
|
|
if m.Host == "" {
|
|
return errors.New("host is required")
|
|
}
|
|
if m.Port == 0 {
|
|
return errors.New("port is required")
|
|
}
|
|
if m.Username == "" {
|
|
return errors.New("username is required")
|
|
}
|
|
if m.Password == "" {
|
|
return errors.New("password is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MariadbDatabase) TestConnection(
|
|
logger *slog.Logger,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
defer cancel()
|
|
|
|
if m.Database == nil || *m.Database == "" {
|
|
return errors.New("database name is required for MariaDB backup")
|
|
}
|
|
|
|
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
dsn := m.buildDSN(password, *m.Database)
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to MariaDB database '%s': %w", *m.Database, err)
|
|
}
|
|
defer func() {
|
|
if closeErr := db.Close(); closeErr != nil {
|
|
logger.Error("Failed to close MariaDB connection", "error", closeErr)
|
|
}
|
|
}()
|
|
|
|
db.SetConnMaxLifetime(15 * time.Second)
|
|
db.SetMaxOpenConns(1)
|
|
db.SetMaxIdleConns(1)
|
|
|
|
if err := db.PingContext(ctx); err != nil {
|
|
return fmt.Errorf("failed to ping MariaDB database '%s': %w", *m.Database, err)
|
|
}
|
|
|
|
detectedVersion, err := detectMariadbVersion(ctx, db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.Version = detectedVersion
|
|
|
|
return nil
|
|
}
|
|
|
|
func (m *MariadbDatabase) HideSensitiveData() {
|
|
if m == nil {
|
|
return
|
|
}
|
|
m.Password = ""
|
|
}
|
|
|
|
func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
|
|
m.Version = incoming.Version
|
|
m.Host = incoming.Host
|
|
m.Port = incoming.Port
|
|
m.Username = incoming.Username
|
|
m.Database = incoming.Database
|
|
m.IsHttps = incoming.IsHttps
|
|
|
|
if incoming.Password != "" {
|
|
m.Password = incoming.Password
|
|
}
|
|
}
|
|
|
|
func (m *MariadbDatabase) EncryptSensitiveFields(
|
|
databaseID uuid.UUID,
|
|
encryptor encryption.FieldEncryptor,
|
|
) error {
|
|
if m.Password != "" {
|
|
encrypted, err := encryptor.Encrypt(databaseID, m.Password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.Password = encrypted
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *MariadbDatabase) PopulateVersionIfEmpty(
|
|
logger *slog.Logger,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) error {
|
|
if m.Version != "" {
|
|
return nil
|
|
}
|
|
return m.PopulateVersion(logger, encryptor, databaseID)
|
|
}
|
|
|
|
func (m *MariadbDatabase) PopulateVersion(
|
|
logger *slog.Logger,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) error {
|
|
if m.Database == nil || *m.Database == "" {
|
|
return nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
|
defer cancel()
|
|
|
|
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
dsn := m.buildDSN(password, *m.Database)
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
defer func() {
|
|
if closeErr := db.Close(); closeErr != nil {
|
|
logger.Error("Failed to close connection", "error", closeErr)
|
|
}
|
|
}()
|
|
|
|
detectedVersion, err := detectMariadbVersion(ctx, db)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m.Version = detectedVersion
|
|
return nil
|
|
}
|
|
|
|
func (m *MariadbDatabase) IsUserReadOnly(
|
|
ctx context.Context,
|
|
logger *slog.Logger,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) (bool, error) {
|
|
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
dsn := m.buildDSN(password, *m.Database)
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
defer func() {
|
|
if closeErr := db.Close(); closeErr != nil {
|
|
logger.Error("Failed to close connection", "error", closeErr)
|
|
}
|
|
}()
|
|
|
|
rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check grants: %w", err)
|
|
}
|
|
defer func() { _ = rows.Close() }()
|
|
|
|
writePrivileges := []string{
|
|
"INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER",
|
|
"INDEX", "GRANT OPTION", "ALL PRIVILEGES", "SUPER",
|
|
}
|
|
|
|
for rows.Next() {
|
|
var grant string
|
|
if err := rows.Scan(&grant); err != nil {
|
|
return false, fmt.Errorf("failed to scan grant: %w", err)
|
|
}
|
|
|
|
for _, priv := range writePrivileges {
|
|
if regexp.MustCompile(`(?i)\b` + priv + `\b`).MatchString(grant) {
|
|
return false, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
return false, fmt.Errorf("error iterating grants: %w", err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (m *MariadbDatabase) CreateReadOnlyUser(
|
|
ctx context.Context,
|
|
logger *slog.Logger,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) (string, string, error) {
|
|
password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to decrypt password: %w", err)
|
|
}
|
|
|
|
dsn := m.buildDSN(password, *m.Database)
|
|
|
|
db, err := sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
defer func() {
|
|
if closeErr := db.Close(); closeErr != nil {
|
|
logger.Error("Failed to close connection", "error", closeErr)
|
|
}
|
|
}()
|
|
|
|
maxRetries := 3
|
|
for attempt := range maxRetries {
|
|
// MariaDB 5.5 has a 16-character username limit, use shorter prefix
|
|
newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
|
|
newPassword := uuid.New().String()
|
|
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to begin transaction: %w", err)
|
|
}
|
|
|
|
success := false
|
|
defer func() {
|
|
if !success {
|
|
if rollbackErr := tx.Rollback(); rollbackErr != nil {
|
|
logger.Error("Failed to rollback transaction", "error", rollbackErr)
|
|
}
|
|
}
|
|
}()
|
|
|
|
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
|
"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
|
|
newUsername,
|
|
newPassword,
|
|
))
|
|
if err != nil {
|
|
if attempt < maxRetries-1 {
|
|
continue
|
|
}
|
|
return "", "", fmt.Errorf("failed to create user: %w", err)
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
|
"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
|
|
*m.Database,
|
|
newUsername,
|
|
))
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to grant database privileges: %w", err)
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx, fmt.Sprintf(
|
|
"GRANT PROCESS ON *.* TO '%s'@'%%'",
|
|
newUsername,
|
|
))
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to grant PROCESS privilege: %w", err)
|
|
}
|
|
|
|
_, err = tx.ExecContext(ctx, "FLUSH PRIVILEGES")
|
|
if err != nil {
|
|
return "", "", fmt.Errorf("failed to flush privileges: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return "", "", fmt.Errorf("failed to commit transaction: %w", err)
|
|
}
|
|
|
|
success = true
|
|
logger.Info(
|
|
"Read-only MariaDB user created successfully",
|
|
"username", newUsername,
|
|
)
|
|
return newUsername, newPassword, nil
|
|
}
|
|
|
|
return "", "", errors.New("failed to generate unique username after 3 attempts")
|
|
}
|
|
|
|
func (m *MariadbDatabase) buildDSN(password string, database string) string {
|
|
tlsConfig := "false"
|
|
if m.IsHttps {
|
|
tlsConfig = "true"
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
|
|
m.Username,
|
|
password,
|
|
m.Host,
|
|
m.Port,
|
|
database,
|
|
tlsConfig,
|
|
)
|
|
}
|
|
|
|
// detectMariadbVersion parses VERSION() output to detect MariaDB version
|
|
// MariaDB returns strings like "10.11.6-MariaDB" or "11.4.2-MariaDB-1:11.4.2+maria~ubu2204"
|
|
// Minor versions are mapped to the closest supported version (e.g., 12.1 → 12.0)
|
|
func detectMariadbVersion(ctx context.Context, db *sql.DB) (tools.MariadbVersion, error) {
|
|
var versionStr string
|
|
err := db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&versionStr)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to query MariaDB version: %w", err)
|
|
}
|
|
|
|
if !strings.Contains(strings.ToLower(versionStr), "mariadb") {
|
|
return "", fmt.Errorf(
|
|
"not a MariaDB server (version: %s). Use MySQL database type instead",
|
|
versionStr,
|
|
)
|
|
}
|
|
|
|
re := regexp.MustCompile(`^(\d+)\.(\d+)`)
|
|
matches := re.FindStringSubmatch(versionStr)
|
|
if len(matches) < 3 {
|
|
return "", fmt.Errorf("could not parse MariaDB version: %s", versionStr)
|
|
}
|
|
|
|
major := matches[1]
|
|
minor := matches[2]
|
|
|
|
return mapMariadbVersion(major, minor)
|
|
}
|
|
|
|
func mapMariadbVersion(major, minor string) (tools.MariadbVersion, error) {
|
|
switch major {
|
|
case "5":
|
|
return tools.MariadbVersion55, nil
|
|
case "10":
|
|
return mapMariadb10xVersion(minor)
|
|
case "11":
|
|
return mapMariadb11xVersion(minor)
|
|
case "12":
|
|
return tools.MariadbVersion120, nil
|
|
default:
|
|
return "", fmt.Errorf(
|
|
"unsupported MariaDB major version: %s (supported: 5.x, 10.x, 11.x, 12.x)",
|
|
major,
|
|
)
|
|
}
|
|
}
|
|
|
|
func mapMariadb10xVersion(minor string) (tools.MariadbVersion, error) {
|
|
switch minor {
|
|
case "1":
|
|
return tools.MariadbVersion101, nil
|
|
case "2":
|
|
return tools.MariadbVersion102, nil
|
|
case "3":
|
|
return tools.MariadbVersion103, nil
|
|
case "4":
|
|
return tools.MariadbVersion104, nil
|
|
case "5":
|
|
return tools.MariadbVersion105, nil
|
|
case "6", "7", "8", "9", "10":
|
|
return tools.MariadbVersion106, nil
|
|
default:
|
|
return tools.MariadbVersion1011, nil
|
|
}
|
|
}
|
|
|
|
func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
|
|
switch minor {
|
|
case "0", "1", "2", "3", "4":
|
|
return tools.MariadbVersion114, nil
|
|
case "5", "6", "7", "8":
|
|
return tools.MariadbVersion118, nil
|
|
default:
|
|
return tools.MariadbVersion118, nil
|
|
}
|
|
}
|
|
|
|
func decryptPasswordIfNeeded(
|
|
password string,
|
|
encryptor encryption.FieldEncryptor,
|
|
databaseID uuid.UUID,
|
|
) (string, error) {
|
|
if encryptor == nil {
|
|
return password, nil
|
|
}
|
|
return encryptor.Decrypt(databaseID, password)
|
|
}
|