mirror of
https://github.com/databasus/databasus.git
synced 2026-04-06 00:32:03 +02:00
1262 lines
38 KiB
Go
1262 lines
38 KiB
Go
package billing
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
|
|
"databasus-backend/internal/config"
|
|
billing_models "databasus-backend/internal/features/billing/models"
|
|
billing_provider "databasus-backend/internal/features/billing/provider"
|
|
billing_repositories "databasus-backend/internal/features/billing/repositories"
|
|
"databasus-backend/internal/features/databases"
|
|
users_models "databasus-backend/internal/features/users/models"
|
|
workspaces_services "databasus-backend/internal/features/workspaces/services"
|
|
"databasus-backend/internal/storage"
|
|
"databasus-backend/internal/util/logger"
|
|
)
|
|
|
|
const billingTickerInterval = 5 * time.Minute
|
|
|
|
type BillingService struct {
|
|
subscriptionRepository *billing_repositories.SubscriptionRepository
|
|
subscriptionEventRepository *billing_repositories.SubscriptionEventRepository
|
|
invoiceRepository *billing_repositories.InvoiceRepository
|
|
|
|
billingProvider billing_provider.BillingProvider
|
|
workspaceService *workspaces_services.WorkspaceService
|
|
databaseService databases.DatabaseService
|
|
|
|
runOnce sync.Once
|
|
hasRun atomic.Bool
|
|
}
|
|
|
|
func (s *BillingService) Run(ctx context.Context, logger slog.Logger) {
|
|
wasAlreadyRun := s.hasRun.Load()
|
|
|
|
s.runOnce.Do(func() {
|
|
s.hasRun.Store(true)
|
|
|
|
ticker := time.NewTicker(billingTickerInterval)
|
|
defer ticker.Stop()
|
|
|
|
// Run immediately on start
|
|
expiredSubsLog := logger.With("task_name", "process_expired_subscriptions")
|
|
if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil {
|
|
expiredSubsLog.Error("failed to process expired subscriptions", "error", err)
|
|
}
|
|
|
|
expiredTrialsLog := logger.With("task_name", "process_expired_trials")
|
|
if err := s.processExpiredTrials(expiredTrialsLog); err != nil {
|
|
expiredTrialsLog.Error("failed to process expired trials", "error", err)
|
|
}
|
|
|
|
reconcileSubsLog := logger.With("task_name", "reconcile_subscriptions")
|
|
if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil {
|
|
reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err)
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
if err := s.processExpiredSubscriptions(expiredSubsLog); err != nil {
|
|
expiredSubsLog.Error("failed to process expired subscriptions", "error", err)
|
|
}
|
|
|
|
if err := s.processExpiredTrials(expiredTrialsLog); err != nil {
|
|
expiredTrialsLog.Error("failed to process expired trials", "error", err)
|
|
}
|
|
|
|
if err := s.reconcileSubscriptions(reconcileSubsLog); err != nil {
|
|
reconcileSubsLog.Error("failed to reconcile subscriptions", "error", err)
|
|
}
|
|
}
|
|
}
|
|
})
|
|
|
|
if wasAlreadyRun {
|
|
panic(fmt.Sprintf("%T.Run() called multiple times", s))
|
|
}
|
|
}
|
|
|
|
func (s *BillingService) SetBillingProvider(provider billing_provider.BillingProvider) {
|
|
s.billingProvider = provider
|
|
}
|
|
|
|
func (s *BillingService) OnDatabaseCreated(databaseID uuid.UUID) {
|
|
scopedLog := logger.GetLogger().With("database_id", databaseID)
|
|
if err := s.createTrialSubscription(scopedLog, databaseID); err != nil {
|
|
scopedLog.Error("failed to create trial subscription", "error", err)
|
|
}
|
|
}
|
|
|
|
func (s *BillingService) CreateSubscription(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
databaseID uuid.UUID,
|
|
storageGB int,
|
|
) (checkoutURL string, err error) {
|
|
logger.Debug(fmt.Sprintf("creating subscription for storage %d GB", storageGB))
|
|
|
|
if err := s.validateDatabaseAccess(logger, user, databaseID); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// validate size
|
|
if storageGB < config.GetEnv().MinStorageGB || storageGB > config.GetEnv().MaxStorageGB {
|
|
logger.Error(
|
|
fmt.Sprintf(
|
|
"invalid storage requested: %d GB (allowed %d - %d)",
|
|
storageGB,
|
|
config.GetEnv().MinStorageGB,
|
|
config.GetEnv().MaxStorageGB,
|
|
),
|
|
)
|
|
return "", ErrInvalidStorage
|
|
}
|
|
|
|
// validate active subs (trial is allowed — it will be expired when the paid subscription activates)
|
|
existingSub, err := s.getActiveSubscription(logger, databaseID)
|
|
if err != nil && !errors.Is(err, ErrNoActiveSubscription) {
|
|
logger.Error("failed to check existing subscriptions", "error", err)
|
|
return "", err
|
|
}
|
|
|
|
if existingSub != nil && existingSub.Status != billing_models.StatusTrial {
|
|
logger.Error("active subscription already exists")
|
|
return "", ErrAlreadySubscribed
|
|
}
|
|
|
|
// create checkout session
|
|
url, err := s.billingProvider.CreateCheckoutSession(logger, billing_provider.CheckoutRequest{
|
|
DatabaseID: databaseID,
|
|
Email: user.Email,
|
|
StorageGB: storageGB,
|
|
SuccessURL: "https://app.databasus.com/?payment_success=" + databaseID.String(),
|
|
CancelURL: "https://app.databasus.com/?payment_failed=" + databaseID.String(),
|
|
})
|
|
if err != nil {
|
|
logger.Error("failed to create checkout session", "error", err)
|
|
return "", err
|
|
}
|
|
|
|
logger.Debug("checkout session created", "url", url)
|
|
|
|
return url, nil
|
|
}
|
|
|
|
// SyncSubscriptionFromProvider - syncs subscription state from provider webhook (subscription.updated).
|
|
// Always applies quantity, status, and period from the webhook. Determines event type by comparing
|
|
// old vs new storage: upgraded, downgraded, or new billing cycle started.
|
|
// Important note: this is not the same as payment. Payments come separately via RecordPaymentSuccess.
|
|
func (s *BillingService) SyncSubscriptionFromProvider(
|
|
logger *slog.Logger,
|
|
subscription *billing_models.Subscription,
|
|
webhookEvent billing_models.WebhookEvent,
|
|
) error {
|
|
logger = logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
|
|
logger.Debug("syncing subscription state from provider")
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", subscription.ID).First(&subscription).Error; err != nil {
|
|
logger.Error("failed to lock subscription for provider sync", "error", err)
|
|
return err
|
|
}
|
|
|
|
oldStorageGB := subscription.StorageGB
|
|
oldStatus := subscription.Status
|
|
|
|
subscription.StorageGB = webhookEvent.QuantityGB
|
|
subscription.PendingStorageGB = nil
|
|
subscription.Status = webhookEvent.Status
|
|
subscription.UpdatedAt = time.Now().UTC()
|
|
|
|
if webhookEvent.PeriodStart != nil {
|
|
subscription.CurrentPeriodStart = *webhookEvent.PeriodStart
|
|
}
|
|
if webhookEvent.PeriodEnd != nil {
|
|
subscription.CurrentPeriodEnd = *webhookEvent.PeriodEnd
|
|
}
|
|
|
|
if err := tx.Save(&subscription).Error; err != nil {
|
|
logger.Error("failed to save subscription for provider sync", "error", err)
|
|
return err
|
|
}
|
|
|
|
eventType := billing_models.EventNewBillingCycleStarted
|
|
if oldStorageGB < webhookEvent.QuantityGB {
|
|
eventType = billing_models.EventUpgraded
|
|
} else if oldStorageGB > webhookEvent.QuantityGB {
|
|
eventType = billing_models.EventDowngraded
|
|
}
|
|
|
|
event := billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
Type: eventType,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &subscription.Status,
|
|
ProviderEventID: &webhookEvent.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}
|
|
|
|
if oldStorageGB != subscription.StorageGB {
|
|
event.OldStorageGB = &oldStorageGB
|
|
event.NewStorageGB = &subscription.StorageGB
|
|
}
|
|
|
|
if err := tx.Create(&event).Error; err != nil {
|
|
logger.Error("failed to create subscription event for provider sync", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"subscription synced from provider: %s -> %s, %d GB -> %d GB, period until %s",
|
|
string(oldStatus),
|
|
string(subscription.Status),
|
|
oldStorageGB,
|
|
subscription.StorageGB,
|
|
subscription.CurrentPeriodEnd.Format(time.RFC3339),
|
|
),
|
|
)
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) CancelSubscription(
|
|
logger *slog.Logger,
|
|
sub *billing_models.Subscription,
|
|
webhookEvent billing_models.WebhookEvent,
|
|
) error {
|
|
logger = logger.With("subscription_id", sub.ID, "database_id", sub.DatabaseID)
|
|
logger.Debug(fmt.Sprintf("handling subscription cancel (was %s)", string(sub.Status)))
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", sub.ID).First(&sub).Error; err != nil {
|
|
logger.Error("failed to lock subscription for cancel", "error", err)
|
|
return err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
oldStatus := sub.Status
|
|
sub.Status = billing_models.StatusCanceled
|
|
sub.CanceledAt = &now
|
|
sub.UpdatedAt = now
|
|
|
|
if oldStatus == billing_models.StatusPastDue {
|
|
// past_due -> canceled - immediae cancelation. User
|
|
// is not paying so cannot create new backups (but
|
|
// applying grace period to allow downloading\restore)
|
|
sub.CurrentPeriodEnd = now
|
|
retention := now.Add(config.GetEnv().GracePeriod)
|
|
sub.DataRetentionGracePeriodUntil = &retention
|
|
} else {
|
|
// subscription is active, but will be canceled in the
|
|
// end of the billing period. User allowed to do any
|
|
// actions. Grace period will be applied after end of the billing period,
|
|
// when subscription will be moved to expired status
|
|
retention := sub.CurrentPeriodEnd.Add(config.GetEnv().GracePeriod)
|
|
sub.DataRetentionGracePeriodUntil = &retention
|
|
}
|
|
|
|
if err := tx.Save(&sub).Error; err != nil {
|
|
logger.Error("failed to save subscription for cancel", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: sub.ID,
|
|
Type: billing_models.EventCanceled,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &sub.Status,
|
|
ProviderEventID: &webhookEvent.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for cancel", "error", err)
|
|
return err
|
|
}
|
|
|
|
if oldStatus == billing_models.StatusPastDue {
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"subscription canceled immediately (was from past_due), applying grace period until %s",
|
|
sub.DataRetentionGracePeriodUntil.Format(time.RFC3339),
|
|
),
|
|
)
|
|
} else {
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"subscription cancelation scheduled at the end of billing period %s, applying grace period until %s",
|
|
sub.CurrentPeriodEnd.Format(time.RFC3339),
|
|
sub.DataRetentionGracePeriodUntil.Format(time.RFC3339),
|
|
),
|
|
)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) ReactivateSubscription(
|
|
logger *slog.Logger,
|
|
sub *billing_models.Subscription,
|
|
webhookEvent billing_models.WebhookEvent,
|
|
) error {
|
|
logger = logger.With("subscription_id", sub.ID, "database_id", sub.DatabaseID)
|
|
logger.Debug("handling subscription reactivation (undo cancel)")
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", sub.ID).First(&sub).Error; err != nil {
|
|
logger.Error("failed to lock subscription for reactivation", "error", err)
|
|
return err
|
|
}
|
|
|
|
if sub.Status != billing_models.StatusCanceled {
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"subscription is no longer canceled (status: %s), skipping reactivation",
|
|
string(sub.Status),
|
|
),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
oldStatus := sub.Status
|
|
sub.Status = webhookEvent.Status
|
|
sub.CanceledAt = nil
|
|
sub.DataRetentionGracePeriodUntil = nil
|
|
sub.StorageGB = webhookEvent.QuantityGB
|
|
sub.PendingStorageGB = nil
|
|
sub.UpdatedAt = time.Now().UTC()
|
|
|
|
if webhookEvent.PeriodStart != nil {
|
|
sub.CurrentPeriodStart = *webhookEvent.PeriodStart
|
|
}
|
|
if webhookEvent.PeriodEnd != nil {
|
|
sub.CurrentPeriodEnd = *webhookEvent.PeriodEnd
|
|
}
|
|
|
|
if err := tx.Save(&sub).Error; err != nil {
|
|
logger.Error("failed to save subscription for reactivation", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: sub.ID,
|
|
Type: billing_models.EventReactivated,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &sub.Status,
|
|
ProviderEventID: &webhookEvent.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for reactivation", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(fmt.Sprintf("subscription reactivated: %s -> %s", string(oldStatus), string(sub.Status)))
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) GetPortalURL(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
subscriptionID uuid.UUID,
|
|
) (string, error) {
|
|
logger.Debug("getting billing portal URL")
|
|
|
|
subscription, err := s.getSubscriptionWithAccessCheck(logger, user, subscriptionID)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
logger = logger.With("database_id", subscription.DatabaseID)
|
|
|
|
if subscription.Status != billing_models.StatusActive &&
|
|
subscription.Status != billing_models.StatusPastDue &&
|
|
subscription.Status != billing_models.StatusCanceled {
|
|
logger.Error("subscription is not active", "status", subscription.Status)
|
|
return "", fmt.Errorf("subscription is not active, past due, or canceled")
|
|
}
|
|
|
|
returnURL := "https://app.databasus.com"
|
|
url, err := s.billingProvider.CreatePortalSession(logger, *subscription.ProviderCustomerID, returnURL)
|
|
if err != nil {
|
|
logger.Error("failed to create portal session", "error", err)
|
|
return "", err
|
|
}
|
|
|
|
logger.Debug("billing portal session created", "url", url)
|
|
|
|
return url, nil
|
|
}
|
|
|
|
func (s *BillingService) GetSubscriptionEvents(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
subscriptionID uuid.UUID,
|
|
limit, offset int,
|
|
) (*GetSubscriptionEventsResponse, error) {
|
|
subscription, err := s.getSubscriptionWithAccessCheck(logger, user, subscriptionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger = logger.With("database_id", subscription.DatabaseID)
|
|
|
|
limit = normalizePaginationLimit(limit)
|
|
offset = max(offset, 0)
|
|
|
|
events, err := s.subscriptionEventRepository.FindByDatabaseID(subscription.DatabaseID, limit, offset)
|
|
if err != nil {
|
|
logger.Error("failed to get subscription events", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
total, err := s.subscriptionEventRepository.CountByDatabaseID(subscription.DatabaseID)
|
|
if err != nil {
|
|
logger.Error("failed to count subscription events", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
return &GetSubscriptionEventsResponse{
|
|
Events: events,
|
|
Total: total,
|
|
Limit: limit,
|
|
Offset: offset,
|
|
}, nil
|
|
}
|
|
|
|
func (s *BillingService) GetSubscriptionInvoices(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
subscriptionID uuid.UUID,
|
|
limit, offset int,
|
|
) (*GetInvoicesResponse, error) {
|
|
subscription, err := s.getSubscriptionWithAccessCheck(logger, user, subscriptionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
logger = logger.With("database_id", subscription.DatabaseID)
|
|
|
|
limit = normalizePaginationLimit(limit)
|
|
offset = max(offset, 0)
|
|
|
|
invoices, err := s.invoiceRepository.FindByDatabaseID(subscription.DatabaseID, limit, offset)
|
|
if err != nil {
|
|
logger.Error("failed to get subscription invoices", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
total, err := s.invoiceRepository.CountByDatabaseID(subscription.DatabaseID)
|
|
if err != nil {
|
|
logger.Error("failed to count subscription invoices", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
return &GetInvoicesResponse{
|
|
Invoices: invoices,
|
|
Total: total,
|
|
Limit: limit,
|
|
Offset: offset,
|
|
}, nil
|
|
}
|
|
|
|
func normalizePaginationLimit(limit int) int {
|
|
if limit <= 0 || limit > 1000 {
|
|
return 100
|
|
}
|
|
|
|
return limit
|
|
}
|
|
|
|
func (s *BillingService) GetSubscriptionByDatabaseID(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
databaseID uuid.UUID,
|
|
) (*billing_models.Subscription, error) {
|
|
logger = logger.With("database_id", databaseID)
|
|
logger.Debug("getting subscription by database ID")
|
|
|
|
if err := s.validateDatabaseAccess(logger, user, databaseID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
subscription, err := s.GetSubscription(logger, databaseID)
|
|
if err != nil {
|
|
logger.Error("failed to get subscription", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
return subscription, nil
|
|
}
|
|
|
|
func (s *BillingService) RecordPaymentFailed(
|
|
logger *slog.Logger,
|
|
subscription *billing_models.Subscription,
|
|
event billing_models.WebhookEvent,
|
|
) error {
|
|
logger = logger.With(
|
|
"subscription_id", subscription.ID,
|
|
"database_id", subscription.DatabaseID,
|
|
"provider_invoice_id", event.ProviderInvoiceID,
|
|
)
|
|
logger.Debug("recording payment failure for subscription")
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", subscription.ID).First(&subscription).Error; err != nil {
|
|
logger.Error("failed to lock subscription for payment failure", "error", err)
|
|
return err
|
|
}
|
|
|
|
oldStatus := subscription.Status
|
|
subscription.Status = billing_models.StatusPastDue
|
|
subscription.UpdatedAt = time.Now().UTC()
|
|
|
|
if err := tx.Save(&subscription).Error; err != nil {
|
|
logger.Error("failed to save subscription for payment failure", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
Type: billing_models.EventPastDue,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &subscription.Status,
|
|
ProviderEventID: &event.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for payment failure", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(fmt.Sprintf("subscription marked as past_due due to payment failure (was %s)", string(oldStatus)))
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) RecordDispute(logger *slog.Logger, event billing_models.WebhookEvent) error {
|
|
logger = logger.With("provider_invoice_id", event.ProviderInvoiceID)
|
|
logger.Debug("recording dispute for subscription")
|
|
|
|
invoice, err := s.invoiceRepository.FindByProviderInvID(event.ProviderInvoiceID)
|
|
if err != nil {
|
|
logger.Error("failed to find invoice for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
subscription, err := s.subscriptionRepository.FindByID(invoice.SubscriptionID)
|
|
if err != nil {
|
|
logger.Error("failed to find subscription for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger = logger.With(
|
|
"subscription_id", subscription.ID,
|
|
"database_id", subscription.DatabaseID,
|
|
"invoice_id", invoice.ID,
|
|
)
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", subscription.ID).First(&subscription).Error; err != nil {
|
|
logger.Error("failed to lock subscription for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
oldStatus := subscription.Status
|
|
subscription.Status = billing_models.StatusCanceled
|
|
subscription.UpdatedAt = time.Now().UTC()
|
|
|
|
// Dispute does not have grace period. We provide grace period
|
|
// for canceled subscriptions and accidentally expired
|
|
|
|
now := time.Now().UTC()
|
|
subscription.CanceledAt = &now
|
|
subscription.UpdatedAt = now
|
|
|
|
if err := tx.Save(&subscription).Error; err != nil {
|
|
logger.Error("failed to save subscription for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
invoice.Status = billing_models.InvoiceStatusDisputed
|
|
if err := tx.Save(&invoice).Error; err != nil {
|
|
logger.Error("failed to save invoice for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
Type: billing_models.EventDispute,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &subscription.Status,
|
|
ProviderEventID: &event.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for dispute", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(fmt.Sprintf("subscription canceled due to dispute (was %s)", string(oldStatus)))
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) ChangeSubscriptionStorage(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
databaseID uuid.UUID,
|
|
newStorageGB int,
|
|
) (*ChangeStorageResult, error) {
|
|
if err := s.validateDatabaseAccess(logger, user, databaseID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if newStorageGB < config.GetEnv().MinStorageGB || newStorageGB > config.GetEnv().MaxStorageGB {
|
|
logger.Error(
|
|
fmt.Sprintf(
|
|
"invalid storage requested for change: %d GB (allowed %d - %d)",
|
|
newStorageGB,
|
|
config.GetEnv().MinStorageGB,
|
|
config.GetEnv().MaxStorageGB,
|
|
),
|
|
)
|
|
return nil, ErrInvalidStorage
|
|
}
|
|
|
|
activeSub, err := s.getActiveSubscription(logger, databaseID)
|
|
if err != nil {
|
|
logger.Error("failed to find active subscription for storage change", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
logger.Debug(fmt.Sprintf("changing subscription storage to %d GB", newStorageGB))
|
|
logger = logger.With("subscription_id", activeSub.ID)
|
|
|
|
if newStorageGB == activeSub.StorageGB {
|
|
logger.Warn("requested storage is the same as current")
|
|
return nil, ErrNoChange
|
|
}
|
|
|
|
if newStorageGB > activeSub.StorageGB {
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"requested storage is greater than current, applying upgrade: %d GB -> %d GB",
|
|
activeSub.StorageGB,
|
|
newStorageGB,
|
|
),
|
|
)
|
|
return s.applyUpgrade(logger, activeSub, newStorageGB)
|
|
} else {
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"requested storage is less than current, applying downgrade: %d GB -> %d GB",
|
|
activeSub.StorageGB,
|
|
newStorageGB,
|
|
),
|
|
)
|
|
return s.applyDowngrade(logger, activeSub, newStorageGB)
|
|
}
|
|
}
|
|
|
|
func (s *BillingService) ActivateSubscription(logger *slog.Logger, webhookEvent billing_models.WebhookEvent) error {
|
|
logger.Debug("handling subscription created")
|
|
|
|
databaseID := webhookEvent.DatabaseID
|
|
if databaseID == nil {
|
|
logger.Error("database ID is missing in webhook event")
|
|
return fmt.Errorf("database ID is missing in webhook event")
|
|
}
|
|
|
|
logger = logger.With("database_id", *databaseID)
|
|
|
|
existingSubscription, err := s.subscriptionRepository.FindByProviderSubID(webhookEvent.ProviderSubscriptionID)
|
|
if err == nil && existingSubscription != nil {
|
|
logger.Warn("subscription already existing, idempotent skip", "subscription_id", existingSubscription.ID)
|
|
return nil
|
|
}
|
|
|
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
|
logger.Error("failed to check existing subscription for activation", "error", err)
|
|
return err
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
providerName := string(s.billingProvider.GetProviderName())
|
|
|
|
subscription := billing_models.Subscription{
|
|
ID: uuid.New(),
|
|
DatabaseID: *databaseID,
|
|
Status: billing_models.StatusActive,
|
|
StorageGB: webhookEvent.QuantityGB,
|
|
CurrentPeriodStart: *webhookEvent.PeriodStart,
|
|
CurrentPeriodEnd: *webhookEvent.PeriodEnd,
|
|
ProviderName: &providerName,
|
|
ProviderSubID: &webhookEvent.ProviderSubscriptionID,
|
|
ProviderCustomerID: &webhookEvent.ProviderCustomerID,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
logger = logger.With(
|
|
"subscription_id", subscription.ID,
|
|
"provider_subscription_id", webhookEvent.ProviderSubscriptionID,
|
|
"provider_customer_id", webhookEvent.ProviderCustomerID,
|
|
)
|
|
|
|
if err := storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
// expire any existing trial subscription for this database
|
|
trialSubs, findErr := s.subscriptionRepository.FindByDatabaseIDAndStatuses(
|
|
*databaseID,
|
|
[]billing_models.SubscriptionStatus{billing_models.StatusTrial},
|
|
)
|
|
if findErr != nil {
|
|
logger.Error("failed to find trial subscriptions", "error", findErr)
|
|
return findErr
|
|
}
|
|
|
|
for _, trialSub := range trialSubs {
|
|
now := time.Now().UTC()
|
|
trialSub.Status = billing_models.StatusExpired
|
|
trialSub.CanceledAt = &now
|
|
trialSub.UpdatedAt = now
|
|
|
|
if err := tx.Save(trialSub).Error; err != nil {
|
|
logger.Error(
|
|
"failed to expire trial subscription during activation",
|
|
"error",
|
|
err,
|
|
"subscription_id",
|
|
trialSub.ID,
|
|
)
|
|
return err
|
|
}
|
|
|
|
logger.Info("expired trial subscription during paid activation", "trial_subscription_id", trialSub.ID)
|
|
}
|
|
|
|
if err := tx.Create(&subscription).Error; err != nil {
|
|
logger.Error("failed to create subscription for activation", "error", err)
|
|
return err
|
|
}
|
|
|
|
newStatus := subscription.Status
|
|
return tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
Type: billing_models.EventCreated,
|
|
NewStorageGB: &subscription.StorageGB,
|
|
NewStatus: &newStatus,
|
|
ProviderEventID: &webhookEvent.ProviderEventID,
|
|
}).Error
|
|
}); err != nil {
|
|
logger.Error("failed to activate subscription", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info("subscription activated", "subscription_id", subscription.ID)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) RecordPaymentSuccess(
|
|
logger *slog.Logger,
|
|
subscription *billing_models.Subscription,
|
|
webhookEvent billing_models.WebhookEvent,
|
|
) error {
|
|
logger = logger.With(
|
|
"subscription_id", subscription.ID,
|
|
"database_id", subscription.DatabaseID,
|
|
"provider_invoice_id", webhookEvent.ProviderInvoiceID,
|
|
)
|
|
logger.Debug("recording payment success for subscription")
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", subscription.ID).First(&subscription).Error; err != nil {
|
|
logger.Error("failed to lock subscription for payment success", "error", err)
|
|
return err
|
|
}
|
|
|
|
// if was past_due -> move back to active
|
|
if subscription.Status == billing_models.StatusPastDue {
|
|
oldStatus := subscription.Status
|
|
subscription.Status = billing_models.StatusActive
|
|
subscription.UpdatedAt = time.Now().UTC()
|
|
|
|
if err := tx.Save(&subscription).Error; err != nil {
|
|
logger.Error("failed to save subscription for payment success", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
Type: billing_models.EventRecoveredFromPastDue,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &subscription.Status,
|
|
ProviderEventID: &webhookEvent.ProviderEventID,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for payment recovery", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"subscription recovered from past_due due to successful payment (was %s)",
|
|
string(oldStatus),
|
|
),
|
|
)
|
|
}
|
|
|
|
// check if invoice already exists (idempotency guard)
|
|
var existingInvoice billing_models.Invoice
|
|
invoiceLookupErr := tx.Where("provider_invoice_id = ?", webhookEvent.ProviderInvoiceID).
|
|
First(&existingInvoice).Error
|
|
|
|
if invoiceLookupErr == nil {
|
|
logger.Info("invoice already exists, idempotent skip",
|
|
"provider_invoice_id", webhookEvent.ProviderInvoiceID,
|
|
"existing_invoice_id", existingInvoice.ID,
|
|
)
|
|
return nil
|
|
}
|
|
|
|
if !errors.Is(invoiceLookupErr, gorm.ErrRecordNotFound) {
|
|
logger.Error("failed to check existing invoice", "error", invoiceLookupErr)
|
|
return invoiceLookupErr
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
invoice := billing_models.Invoice{
|
|
ID: uuid.New(),
|
|
SubscriptionID: subscription.ID,
|
|
ProviderInvoiceID: webhookEvent.ProviderInvoiceID,
|
|
AmountCents: webhookEvent.AmountCents,
|
|
Status: billing_models.InvoiceStatusPaid,
|
|
StorageGB: webhookEvent.QuantityGB,
|
|
PeriodStart: *webhookEvent.PeriodStart,
|
|
PeriodEnd: *webhookEvent.PeriodEnd,
|
|
PaidAt: &now,
|
|
}
|
|
if err := tx.Create(&invoice).Error; err != nil {
|
|
logger.Error("failed to create invoice for payment success", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(
|
|
fmt.Sprintf("invoice recorded: %d cents USD for %d GB", webhookEvent.AmountCents, webhookEvent.QuantityGB),
|
|
)
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) GetSubscriptionByProviderSubID(
|
|
logger *slog.Logger,
|
|
providerSubID string,
|
|
) (*billing_models.Subscription, error) {
|
|
logger = logger.With("provider_subscription_id", providerSubID)
|
|
logger.Debug("getting subscription by provider subscription ID")
|
|
|
|
subscription, err := s.subscriptionRepository.FindByProviderSubID(providerSubID)
|
|
if err != nil {
|
|
logger.Error("failed to find subscription by provider subscription ID", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
if subscription == nil {
|
|
return nil, ErrSubscriptionNotFound
|
|
}
|
|
|
|
return subscription, nil
|
|
}
|
|
|
|
func (s *BillingService) GetSubscription(
|
|
logger *slog.Logger,
|
|
databaseID uuid.UUID,
|
|
) (*billing_models.Subscription, error) {
|
|
subscription, err := s.subscriptionRepository.FindLatestByDatabaseID(databaseID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if subscription == nil {
|
|
return nil, ErrSubscriptionNotFound
|
|
}
|
|
|
|
return subscription, nil
|
|
}
|
|
|
|
func (s *BillingService) getActiveSubscription(
|
|
logger *slog.Logger,
|
|
databaseID uuid.UUID,
|
|
) (*billing_models.Subscription, error) {
|
|
activeSubs, err := s.subscriptionRepository.FindByDatabaseIDAndStatuses(
|
|
databaseID,
|
|
[]billing_models.SubscriptionStatus{
|
|
billing_models.StatusActive,
|
|
billing_models.StatusTrial,
|
|
billing_models.StatusPastDue,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(activeSubs) == 0 {
|
|
return nil, ErrNoActiveSubscription
|
|
}
|
|
|
|
if len(activeSubs) > 1 {
|
|
logger.Error(fmt.Sprintf("multiple active subscriptions found: %d", len(activeSubs)))
|
|
}
|
|
|
|
return activeSubs[0], nil
|
|
}
|
|
|
|
func (s *BillingService) reconcileSubscriptions(logger *slog.Logger) error {
|
|
logger.Debug("starting subscription reconciliation")
|
|
|
|
subscriptions, err := s.subscriptionRepository.FindByStatuses([]billing_models.SubscriptionStatus{
|
|
billing_models.StatusActive,
|
|
billing_models.StatusPastDue,
|
|
})
|
|
if err != nil {
|
|
logger.Error("failed to find subscriptions for reconciliation", "error", err)
|
|
return err
|
|
}
|
|
|
|
for _, subscription := range subscriptions {
|
|
scopedLog := logger.With(
|
|
"subscription_id", subscription.ID,
|
|
"database_id", subscription.DatabaseID,
|
|
"provider_subscription_id", subscription.ProviderSubID,
|
|
)
|
|
|
|
providerSubscription, err := s.billingProvider.GetSubscription(scopedLog, *subscription.ProviderSubID)
|
|
if err != nil {
|
|
scopedLog.Error("failed to get subscription from billing provider during reconciliation", "error", err)
|
|
continue
|
|
}
|
|
|
|
if subscription.Status != providerSubscription.Status {
|
|
scopedLog.Error(
|
|
fmt.Sprintf(
|
|
"subscription status mismatch with billing provider, local: %s, provider: %s",
|
|
subscription.Status,
|
|
providerSubscription.Status,
|
|
),
|
|
)
|
|
continue
|
|
}
|
|
|
|
if subscription.StorageGB != providerSubscription.QuantityGB {
|
|
scopedLog.Error(
|
|
fmt.Sprintf(
|
|
"subscription storage mismatch with billing provider, local: %d GB, provider: %d GB",
|
|
subscription.StorageGB,
|
|
providerSubscription.QuantityGB,
|
|
),
|
|
)
|
|
continue
|
|
}
|
|
}
|
|
|
|
logger.Debug("subscription reconciliation completed")
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) processExpiredSubscriptions(logger *slog.Logger) error {
|
|
logger.Debug("started expiring subscriptions processing")
|
|
|
|
subsWithEndedGracePeriod, err := s.subscriptionRepository.FindCanceledWithEndedGracePeriod(time.Now().UTC())
|
|
if err != nil {
|
|
logger.Error("failed to find canceled subscriptions with ended grace period", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Debug(fmt.Sprintf("found %d canceled subscriptions past retention", len(subsWithEndedGracePeriod)))
|
|
|
|
for _, subscription := range subsWithEndedGracePeriod {
|
|
scopedLog := logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
|
|
err = s.expireSubscription(scopedLog, &subscription)
|
|
if err != nil {
|
|
scopedLog.Error("failed to expire subscription", "error", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) expireSubscription(logger *slog.Logger, sub *billing_models.Subscription) error {
|
|
logger = logger.With("subscription_id", sub.ID, "database_id", sub.DatabaseID)
|
|
logger.Debug(fmt.Sprintf("expiring subscription (was %s)", string(sub.Status)))
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", sub.ID).First(&sub).Error; err != nil {
|
|
logger.Error("failed to lock subscription for expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
oldStatus := sub.Status
|
|
sub.Status = billing_models.StatusExpired
|
|
sub.UpdatedAt = time.Now().UTC()
|
|
|
|
if err := tx.Save(&sub).Error; err != nil {
|
|
logger.Error("failed to save subscription for expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: sub.ID,
|
|
Type: billing_models.EventExpired,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &sub.Status,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(fmt.Sprintf("subscription expired (was %s)", string(oldStatus)))
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) processExpiredTrials(logger *slog.Logger) error {
|
|
logger.Debug("started expiring trial subscriptions processing")
|
|
|
|
trials, err := s.subscriptionRepository.FindExpiredTrials(time.Now())
|
|
if err != nil {
|
|
logger.Error("failed to find expired trial subscriptions", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Debug(fmt.Sprintf("found %d expired trial subscriptions", len(trials)))
|
|
|
|
for _, subscription := range trials {
|
|
scopedLog := logger.With("subscription_id", subscription.ID, "database_id", subscription.DatabaseID)
|
|
err = s.expireTrialSubscription(scopedLog, &subscription)
|
|
if err != nil {
|
|
scopedLog.Error("failed to expire trial subscription", "error", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) expireTrialSubscription(logger *slog.Logger, sub *billing_models.Subscription) error {
|
|
logger = logger.With("subscription_id", sub.ID, "database_id", sub.DatabaseID)
|
|
logger.Debug("expiring trial subscription")
|
|
|
|
return storage.GetDb().Transaction(func(tx *gorm.DB) error {
|
|
if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).
|
|
Where("id = ?", sub.ID).First(&sub).Error; err != nil {
|
|
logger.Error("failed to lock subscription for trial expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
oldStatus := sub.Status
|
|
sub.Status = billing_models.StatusExpired
|
|
sub.UpdatedAt = time.Now().UTC()
|
|
|
|
now := time.Now().UTC()
|
|
sub.CanceledAt = &now
|
|
sub.UpdatedAt = now
|
|
|
|
if err := tx.Save(&sub).Error; err != nil {
|
|
logger.Error("failed to save subscription for trial expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
if err := tx.Create(&billing_models.SubscriptionEvent{
|
|
ID: uuid.New(),
|
|
SubscriptionID: sub.ID,
|
|
Type: billing_models.EventExpired,
|
|
OldStatus: &oldStatus,
|
|
NewStatus: &sub.Status,
|
|
CreatedAt: time.Now().UTC(),
|
|
}).Error; err != nil {
|
|
logger.Error("failed to create subscription event for trial expire", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(fmt.Sprintf("trial subscription expired (was %s)", string(oldStatus)))
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (s *BillingService) createTrialSubscription(logger *slog.Logger, databaseID uuid.UUID) error {
|
|
logger = logger.With("database_id", databaseID)
|
|
|
|
dbCreatedAt := time.Now().UTC()
|
|
trialEnds := dbCreatedAt.Add(config.GetEnv().TrialDuration)
|
|
|
|
logger.Debug(
|
|
fmt.Sprintf(
|
|
"creating trial subscription: %d GB, expires %s",
|
|
config.GetEnv().TrialStorageGB,
|
|
trialEnds.Format(time.RFC3339),
|
|
),
|
|
)
|
|
|
|
sub := billing_models.Subscription{
|
|
ID: uuid.New(),
|
|
DatabaseID: databaseID,
|
|
Status: billing_models.StatusTrial,
|
|
StorageGB: config.GetEnv().TrialStorageGB,
|
|
CurrentPeriodStart: dbCreatedAt,
|
|
CurrentPeriodEnd: trialEnds,
|
|
CreatedAt: time.Now().UTC(),
|
|
UpdatedAt: time.Now().UTC(),
|
|
}
|
|
|
|
logger = logger.With("subscription_id", sub.ID)
|
|
|
|
if err := s.subscriptionRepository.Save(sub); err != nil {
|
|
logger.Error("failed to save trial subscription", "error", err)
|
|
return err
|
|
}
|
|
|
|
logger.Info(
|
|
fmt.Sprintf(
|
|
"trial subscription created: %d GB, expires %s",
|
|
config.GetEnv().TrialStorageGB,
|
|
trialEnds.Format(time.RFC3339),
|
|
),
|
|
)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) applyUpgrade(
|
|
logger *slog.Logger,
|
|
sub *billing_models.Subscription,
|
|
newStorageGB int,
|
|
) (*ChangeStorageResult, error) {
|
|
logger.Debug(fmt.Sprintf("applying upgrade for subscription: new storage %d GB", newStorageGB))
|
|
|
|
err := s.billingProvider.UpgradeQuantityWithSurcharge(logger, *sub.ProviderSubID, newStorageGB)
|
|
if err != nil {
|
|
logger.Error("failed to apply upgrade with billing provider", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
logger.Debug("upgrade requested, waiting for billing provider webhook to update subscription")
|
|
|
|
return &ChangeStorageResult{
|
|
ApplyMode: ChangeStorageApplyImmediate,
|
|
CurrentGB: sub.StorageGB,
|
|
PendingGB: &newStorageGB,
|
|
}, nil
|
|
}
|
|
|
|
func (s *BillingService) applyDowngrade(
|
|
logger *slog.Logger,
|
|
sub *billing_models.Subscription,
|
|
newStorageGB int,
|
|
) (*ChangeStorageResult, error) {
|
|
logger.Debug(fmt.Sprintf("applying downgrade for subscription: new storage %d GB", newStorageGB))
|
|
|
|
err := s.billingProvider.ScheduleQuantityDowngradeFromNextBillingCycle(logger, *sub.ProviderSubID, newStorageGB)
|
|
if err != nil {
|
|
logger.Error("failed to schedule downgrade with billing provider", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
sub.PendingStorageGB = &newStorageGB
|
|
sub.UpdatedAt = time.Now().UTC()
|
|
if err := s.subscriptionRepository.Save(*sub); err != nil {
|
|
logger.Error("failed to save subscription with pending downgrade", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
logger.Debug("downgrade scheduled for next billing cycle")
|
|
|
|
return &ChangeStorageResult{
|
|
ApplyMode: ChangeStorageApplyNextCycle,
|
|
CurrentGB: sub.StorageGB,
|
|
PendingGB: &newStorageGB,
|
|
}, nil
|
|
}
|
|
|
|
func (s *BillingService) validateDatabaseAccess(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
databaseID uuid.UUID,
|
|
) error {
|
|
database, err := s.databaseService.GetDatabaseByID(databaseID)
|
|
if err != nil {
|
|
logger.Error("failed to get database", "error", err)
|
|
return err
|
|
}
|
|
|
|
hasAccess, _, err := s.workspaceService.CanUserAccessWorkspace(*database.WorkspaceID, user)
|
|
if err != nil {
|
|
logger.Error("failed to check workspace access", "error", err)
|
|
return err
|
|
}
|
|
|
|
if !hasAccess {
|
|
logger.Error("user does not have access to the workspace")
|
|
return ErrAccessDenied
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *BillingService) getSubscriptionWithAccessCheck(
|
|
logger *slog.Logger,
|
|
user *users_models.User,
|
|
subscriptionID uuid.UUID,
|
|
) (*billing_models.Subscription, error) {
|
|
subscription, err := s.subscriptionRepository.FindByID(subscriptionID)
|
|
if err != nil {
|
|
logger.Error("failed to find subscription", "error", err)
|
|
return nil, err
|
|
}
|
|
|
|
if subscription == nil {
|
|
logger.Error("subscription not found")
|
|
return nil, ErrSubscriptionNotFound
|
|
}
|
|
|
|
if err := s.validateDatabaseAccess(logger, user, subscription.DatabaseID); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return subscription, nil
|
|
}
|