Files
databasus/backend/internal/features/databases/databases/mariadb/model.go
2025-12-25 21:26:53 +03:00

433 lines
11 KiB
Go

package mariadb
import (
"context"
"database/sql"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"databasus-backend/internal/util/encryption"
"databasus-backend/internal/util/tools"
_ "github.com/go-sql-driver/mysql"
"github.com/google/uuid"
)
// MariadbDatabase holds the connection settings of a MariaDB server used as a
// backup source. It is stored in the mariadb_databases table (see TableName)
// and serialized to JSON using the names in the json tags.
type MariadbDatabase struct {
	// ID is the primary key; the database fills it with gen_random_uuid().
	ID         uuid.UUID            `json:"id" gorm:"primaryKey;type:uuid;default:gen_random_uuid()"`
	// DatabaseID is stored in the nullable database_id column.
	// NOTE(review): presumably references the parent database record — confirm.
	DatabaseID *uuid.UUID           `json:"databaseId" gorm:"type:uuid;column:database_id"`
	// Version is the detected server version, filled by TestConnection and
	// PopulateVersion.
	Version    tools.MariadbVersion `json:"version" gorm:"type:text;not null"`
	// Host and Port address the server over TCP.
	Host       string               `json:"host" gorm:"type:text;not null"`
	Port       int                  `json:"port" gorm:"type:int;not null"`
	// Username and Password are the login credentials. Password holds the
	// encrypted value once EncryptSensitiveFields has run.
	Username   string               `json:"username" gorm:"type:text;not null"`
	Password   string               `json:"password" gorm:"type:text;not null"`
	// Database is the schema to connect to; TestConnection requires it.
	Database   *string              `json:"database" gorm:"type:text"`
	// IsHttps switches the driver's tls DSN parameter to "true".
	IsHttps    bool                 `json:"isHttps" gorm:"type:boolean;default:false"`
}
// TableName returns the name of the table in which MariadbDatabase rows are
// persisted.
func (m *MariadbDatabase) TableName() string {
	const table = "mariadb_databases"
	return table
}
// Validate checks that the settings needed to reach the server are present and
// sane. Host, Username and Password must be non-empty. Port must be a valid TCP
// port (1-65535). The first problem found is returned; nil means the settings
// passed.
func (m *MariadbDatabase) Validate() error {
	if m.Host == "" {
		return errors.New("host is required")
	}
	if m.Port == 0 {
		return errors.New("port is required")
	}
	// Negative or oversized values would otherwise only fail later, inside the
	// driver, with a less helpful message.
	if m.Port < 0 || m.Port > 65535 {
		return fmt.Errorf("port must be between 1 and 65535, got %d", m.Port)
	}
	if m.Username == "" {
		return errors.New("username is required")
	}
	if m.Password == "" {
		return errors.New("password is required")
	}
	return nil
}
// TestConnection checks that the configured MariaDB server is reachable with
// the stored credentials. On success it stores the detected server version in
// m.Version.
//
// A non-empty database name is mandatory. The password is decrypted with
// encryptor when one is given. The ping and the version query share a
// 15-second timeout.
func (m *MariadbDatabase) TestConnection(
	logger *slog.Logger,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) error {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	hasDatabase := m.Database != nil && *m.Database != ""
	if !hasDatabase {
		return errors.New("database name is required for MariaDB backup")
	}
	dbName := *m.Database

	plainPassword, decryptErr := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
	if decryptErr != nil {
		return fmt.Errorf("failed to decrypt password: %w", decryptErr)
	}

	conn, openErr := sql.Open("mysql", m.buildDSN(plainPassword, dbName))
	if openErr != nil {
		return fmt.Errorf("failed to connect to MariaDB database '%s': %w", dbName, openErr)
	}
	defer func() {
		closeErr := conn.Close()
		if closeErr != nil {
			logger.Error("Failed to close MariaDB connection", "error", closeErr)
		}
	}()

	// One short-lived connection is enough for a ping plus a single query.
	conn.SetConnMaxLifetime(15 * time.Second)
	conn.SetMaxOpenConns(1)
	conn.SetMaxIdleConns(1)

	if pingErr := conn.PingContext(ctx); pingErr != nil {
		return fmt.Errorf("failed to ping MariaDB database '%s': %w", dbName, pingErr)
	}

	version, detectErr := detectMariadbVersion(ctx, conn)
	if detectErr != nil {
		return detectErr
	}
	m.Version = version
	return nil
}
// HideSensitiveData blanks the password so the value can be returned to
// clients without credentials. It is safe to call on a nil receiver.
func (m *MariadbDatabase) HideSensitiveData() {
	if m != nil {
		m.Password = ""
	}
}
// Update copies the connection settings from incoming onto m. Version, Host,
// Port, Username, Database and IsHttps are always overwritten. Password is only
// replaced when incoming carries a non-empty value, so an omitted password
// keeps the stored one. ID and DatabaseID are never touched.
func (m *MariadbDatabase) Update(incoming *MariadbDatabase) {
	if incoming.Password != "" {
		m.Password = incoming.Password
	}
	m.Version, m.Host, m.Port = incoming.Version, incoming.Host, incoming.Port
	m.Username, m.Database, m.IsHttps = incoming.Username, incoming.Database, incoming.IsHttps
}
// EncryptSensitiveFields replaces a non-empty Password with the result of
// encryptor.Encrypt(databaseID, Password). An empty password is left as is.
// If encryption fails, the error is returned and Password is unchanged.
//
// NOTE(review): nothing here prevents encrypting an already-encrypted value —
// confirm callers invoke this exactly once per plaintext password.
func (m *MariadbDatabase) EncryptSensitiveFields(
	databaseID uuid.UUID,
	encryptor encryption.FieldEncryptor,
) error {
	if m.Password == "" {
		return nil
	}

	ciphertext, err := encryptor.Encrypt(databaseID, m.Password)
	if err != nil {
		return err
	}
	m.Password = ciphertext
	return nil
}
// PopulateVersionIfEmpty runs PopulateVersion only when no version has been
// recorded yet. Otherwise it returns nil without contacting the server.
func (m *MariadbDatabase) PopulateVersionIfEmpty(
	logger *slog.Logger,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) error {
	if m.Version == "" {
		return m.PopulateVersion(logger, encryptor, databaseID)
	}
	return nil
}
// PopulateVersion connects to the server and stores its detected MariaDB
// version in m.Version.
//
// It does nothing and returns nil when no database name is configured. The
// version query runs under a 15-second timeout.
func (m *MariadbDatabase) PopulateVersion(
	logger *slog.Logger,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) error {
	noDatabase := m.Database == nil || *m.Database == ""
	if noDatabase {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	plainPassword, decryptErr := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
	if decryptErr != nil {
		return fmt.Errorf("failed to decrypt password: %w", decryptErr)
	}

	conn, openErr := sql.Open("mysql", m.buildDSN(plainPassword, *m.Database))
	if openErr != nil {
		return fmt.Errorf("failed to connect to database: %w", openErr)
	}
	defer func() {
		if closeErr := conn.Close(); closeErr != nil {
			logger.Error("Failed to close connection", "error", closeErr)
		}
	}()

	version, detectErr := detectMariadbVersion(ctx, conn)
	if detectErr != nil {
		return detectErr
	}
	m.Version = version
	return nil
}
// writePrivilegePattern matches any privilege keyword that IsUserReadOnly
// treats as granting write access. Matching is case-insensitive and on word
// boundaries. It is compiled once at package initialization rather than once
// per grant line and privilege.
var writePrivilegePattern = regexp.MustCompile(
	`(?i)\b(?:INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|INDEX|GRANT OPTION|ALL PRIVILEGES|SUPER)\b`,
)

// IsUserReadOnly reports whether the configured account holds none of the
// privileges matched by writePrivilegePattern. It inspects every row returned
// by SHOW GRANTS FOR CURRENT_USER().
//
// Each grant line is matched as a whole. A keyword appearing anywhere in it,
// even inside an object or account name, therefore counts as a write privilege.
//
// Database is optional here: SHOW GRANTS needs no default schema, so a nil
// Database connects without selecting one.
func (m *MariadbDatabase) IsUserReadOnly(
	ctx context.Context,
	logger *slog.Logger,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) (bool, error) {
	password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
	if err != nil {
		return false, fmt.Errorf("failed to decrypt password: %w", err)
	}

	// Guard against a nil Database instead of dereferencing it blindly.
	database := ""
	if m.Database != nil {
		database = *m.Database
	}

	dsn := m.buildDSN(password, database)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return false, fmt.Errorf("failed to connect to database: %w", err)
	}
	defer func() {
		if closeErr := db.Close(); closeErr != nil {
			logger.Error("Failed to close connection", "error", closeErr)
		}
	}()

	rows, err := db.QueryContext(ctx, "SHOW GRANTS FOR CURRENT_USER()")
	if err != nil {
		return false, fmt.Errorf("failed to check grants: %w", err)
	}
	defer func() { _ = rows.Close() }() // close error is irrelevant after a read-only query

	for rows.Next() {
		var grant string
		if err := rows.Scan(&grant); err != nil {
			return false, fmt.Errorf("failed to scan grant: %w", err)
		}
		if writePrivilegePattern.MatchString(grant) {
			return false, nil
		}
	}
	if err := rows.Err(); err != nil {
		return false, fmt.Errorf("error iterating grants: %w", err)
	}

	return true, nil
}
// CreateReadOnlyUser creates a new MariaDB account that can read and dump the
// configured database, and returns its username and password. It connects with
// the credentials stored on m, decrypting the password with encryptor when one
// is given.
//
// The account:
//   - is named 'pgs-<8 hex chars>'@'%' (12 characters, within the 16-character
//     limit of MariaDB 5.5);
//   - gets a random UUID as its password;
//   - receives SELECT, SHOW VIEW, LOCK TABLES, TRIGGER and EVENT on the
//     configured database, plus the global PROCESS privilege.
//
// A failing CREATE USER is retried with a fresh name, up to three attempts in
// total. A failure in any later step is returned immediately. In that case the
// half-configured account is dropped on a best-effort basis.
//
// A non-empty Database is required.
func (m *MariadbDatabase) CreateReadOnlyUser(
	ctx context.Context,
	logger *slog.Logger,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) (string, string, error) {
	// Fail before creating anything: the database-level GRANT below cannot
	// succeed without a name, and a nil pointer would panic.
	if m.Database == nil || *m.Database == "" {
		return "", "", errors.New("database name is required to create a read-only MariaDB user")
	}

	password, err := decryptPasswordIfNeeded(m.Password, encryptor, databaseID)
	if err != nil {
		return "", "", fmt.Errorf("failed to decrypt password: %w", err)
	}

	dsn := m.buildDSN(password, *m.Database)
	db, err := sql.Open("mysql", dsn)
	if err != nil {
		return "", "", fmt.Errorf("failed to connect to database: %w", err)
	}
	defer func() {
		if closeErr := db.Close(); closeErr != nil {
			logger.Error("Failed to close connection", "error", closeErr)
		}
	}()

	const maxRetries = 3
	for attempt := range maxRetries {
		// MariaDB 5.5 has a 16-character username limit, use shorter prefix
		newUsername := fmt.Sprintf("pgs-%s", uuid.New().String()[:8])
		newPassword := uuid.New().String()

		retryable, err := m.createReadOnlyUserOnce(ctx, logger, db, newUsername, newPassword)
		if err != nil {
			if retryable && attempt < maxRetries-1 {
				continue
			}
			return "", "", err
		}

		logger.Info(
			"Read-only MariaDB user created successfully",
			"username", newUsername,
		)
		return newUsername, newPassword, nil
	}

	return "", "", errors.New("failed to generate unique username after 3 attempts")
}

// createReadOnlyUserOnce performs one attempt of CreateReadOnlyUser with the
// given credentials. The returned bool is true only when CREATE USER itself
// failed — for example because the generated name is taken. That is the one
// failure a retry with a different name can fix.
func (m *MariadbDatabase) createReadOnlyUserOnce(
	ctx context.Context,
	logger *slog.Logger,
	db *sql.DB,
	username string,
	password string,
) (bool, error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return false, fmt.Errorf("failed to begin transaction: %w", err)
	}

	committed := false
	userCreated := false
	// Deferred per attempt (not inside the caller's loop), so a failed attempt
	// releases its transaction and connection immediately.
	defer func() {
		if committed {
			return
		}
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			logger.Error("Failed to rollback transaction", "error", rollbackErr)
		}
		// CREATE USER and GRANT commit implicitly in MariaDB, so the rollback
		// cannot undo them. Drop the half-configured account explicitly.
		if userCreated {
			dropSQL := fmt.Sprintf("DROP USER '%s'@'%%'", username)
			if _, dropErr := db.ExecContext(ctx, dropSQL); dropErr != nil {
				logger.Error(
					"Failed to drop partially created MariaDB user",
					"username", username,
					"error", dropErr,
				)
			}
		}
	}()

	if _, err := tx.ExecContext(ctx, fmt.Sprintf(
		"CREATE USER '%s'@'%%' IDENTIFIED BY '%s'",
		username,
		password,
	)); err != nil {
		return true, fmt.Errorf("failed to create user: %w", err)
	}
	userCreated = true

	// Escape backticks by doubling them so the name stays a single quoted identifier.
	quotedDatabase := strings.ReplaceAll(*m.Database, "`", "``")
	if _, err := tx.ExecContext(ctx, fmt.Sprintf(
		"GRANT SELECT, SHOW VIEW, LOCK TABLES, TRIGGER, EVENT ON `%s`.* TO '%s'@'%%'",
		quotedDatabase,
		username,
	)); err != nil {
		return false, fmt.Errorf("failed to grant database privileges: %w", err)
	}

	if _, err := tx.ExecContext(ctx, fmt.Sprintf(
		"GRANT PROCESS ON *.* TO '%s'@'%%'",
		username,
	)); err != nil {
		return false, fmt.Errorf("failed to grant PROCESS privilege: %w", err)
	}

	if _, err := tx.ExecContext(ctx, "FLUSH PRIVILEGES"); err != nil {
		return false, fmt.Errorf("failed to flush privileges: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return false, fmt.Errorf("failed to commit transaction: %w", err)
	}
	committed = true
	return false, nil
}
// buildDSN returns a go-sql-driver/mysql DSN for the configured host, port and
// username, using the given plaintext password and database name.
//
// The DSN sets:
//   - parseTime=true;
//   - a 15s dial timeout;
//   - utf8mb4 charset;
//   - tls=true when IsHttps is set, tls=false otherwise.
//
// An empty database connects without selecting a default schema.
func (m *MariadbDatabase) buildDSN(password string, database string) string {
	tlsConfig := "false"
	if m.IsHttps {
		tlsConfig = "true"
	}

	// An IPv6 literal must be bracketed in host:port form. Otherwise its colons
	// are indistinguishable from the port separator and the driver rejects the
	// address. Hosts the user already bracketed are left as they are.
	host := m.Host
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}

	return fmt.Sprintf(
		"%s:%s@tcp(%s:%d)/%s?parseTime=true&timeout=15s&tls=%s&charset=utf8mb4",
		m.Username,
		password,
		host,
		m.Port,
		database,
		tlsConfig,
	)
}
// mariadbVersionPattern captures the leading "major" and "minor" numbers of a
// VERSION() string. It is compiled once at package initialization.
var mariadbVersionPattern = regexp.MustCompile(`^(\d+)\.(\d+)`)

// detectMariadbVersion queries SELECT VERSION() and maps the result to a
// supported tools.MariadbVersion.
//
// MariaDB returns strings such as "10.11.6-MariaDB" or
// "11.4.2-MariaDB-1:11.4.2+maria~ubu2204". Minor versions are mapped to the
// closest supported version (e.g. 12.1 → 12.0).
//
// It returns an error when:
//   - the query fails;
//   - the version string does not mention "mariadb" (case-insensitive);
//   - no leading major.minor can be parsed;
//   - the major version is unsupported.
func detectMariadbVersion(ctx context.Context, db *sql.DB) (tools.MariadbVersion, error) {
	var versionStr string
	if err := db.QueryRowContext(ctx, "SELECT VERSION()").Scan(&versionStr); err != nil {
		return "", fmt.Errorf("failed to query MariaDB version: %w", err)
	}

	// Servers without the "mariadb" marker are rejected and pointed to the
	// MySQL database type instead.
	if !strings.Contains(strings.ToLower(versionStr), "mariadb") {
		return "", fmt.Errorf(
			"not a MariaDB server (version: %s). Use MySQL database type instead",
			versionStr,
		)
	}

	matches := mariadbVersionPattern.FindStringSubmatch(versionStr)
	if len(matches) < 3 {
		return "", fmt.Errorf("could not parse MariaDB version: %s", versionStr)
	}

	major, minor := matches[1], matches[2]
	return mapMariadbVersion(major, minor)
}
// mapMariadbVersion dispatches on the major version:
//   - 5.x maps to 5.5;
//   - 10.x and 11.x are refined by the per-series minor mappers;
//   - every 12.x maps to 12.0;
//   - any other major version is reported as unsupported.
func mapMariadbVersion(major, minor string) (tools.MariadbVersion, error) {
	switch {
	case major == "5":
		return tools.MariadbVersion55, nil
	case major == "10":
		return mapMariadb10xVersion(minor)
	case major == "11":
		return mapMariadb11xVersion(minor)
	case major == "12":
		return tools.MariadbVersion120, nil
	}

	return "", fmt.Errorf(
		"unsupported MariaDB major version: %s (supported: 5.x, 10.x, 11.x, 12.x)",
		major,
	)
}
// mapMariadb10xVersion maps the minor component of a 10.x server version to
// the closest supported 10.x tool version:
//
//   - 10.0 and 10.1 → 10.1. 10.0 has no entry of its own and 10.1 is its
//     nearest supported release; the catch-all default would target 10.11.
//   - 10.2 through 10.5 → the matching release.
//   - 10.6 through 10.10 → 10.6.
//   - 10.11 and any other minor → 10.11.
//
// The error result is currently always nil.
func mapMariadb10xVersion(minor string) (tools.MariadbVersion, error) {
	switch minor {
	case "0", "1":
		return tools.MariadbVersion101, nil
	case "2":
		return tools.MariadbVersion102, nil
	case "3":
		return tools.MariadbVersion103, nil
	case "4":
		return tools.MariadbVersion104, nil
	case "5":
		return tools.MariadbVersion105, nil
	case "6", "7", "8", "9", "10":
		return tools.MariadbVersion106, nil
	default:
		return tools.MariadbVersion1011, nil
	}
}
// mapMariadb11xVersion maps the minor component of an 11.x server version to a
// supported tool version:
//   - 11.0 through 11.4 map to 11.4;
//   - every other minor maps to 11.8.
//
// The error result is always nil.
func mapMariadb11xVersion(minor string) (tools.MariadbVersion, error) {
	switch minor {
	case "0", "1", "2", "3", "4":
		return tools.MariadbVersion114, nil
	}
	return tools.MariadbVersion118, nil
}
// decryptPasswordIfNeeded returns the result of
// encryptor.Decrypt(databaseID, password) when an encryptor is supplied.
// With a nil encryptor the password is returned unchanged.
func decryptPasswordIfNeeded(
	password string,
	encryptor encryption.FieldEncryptor,
	databaseID uuid.UUID,
) (string, error) {
	if encryptor != nil {
		return encryptor.Decrypt(databaseID, password)
	}
	return password, nil
}