fix(billing): add missing helper functions and fix upsert_subscription signature

- Add upsert_billing_customer / get_billing_customer (billing_customers table)
- Add record_transaction (idempotent on provider_transaction_id)
- Fix upsert_subscription: remove provider_customer_id param, key by
  provider_subscription_id instead of user_id (allows multi-sub)
- Update webhook handler to call upsert_billing_customer separately

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Deeman
2026-02-27 14:43:14 +01:00
parent 3faa29d8e5
commit a5d2a61cfb

View File

@@ -42,44 +42,74 @@ async def get_subscription(user_id: int) -> dict | None:
)
async def upsert_billing_customer(user_id: int, provider_customer_id: str) -> None:
"""Create or update billing customer record."""
await execute(
"""INSERT INTO billing_customers (user_id, provider_customer_id)
VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET provider_customer_id = excluded.provider_customer_id""",
(user_id, provider_customer_id),
)
async def get_billing_customer(user_id: int) -> dict | None:
"""Get billing customer record for a user."""
return await fetch_one(
"SELECT * FROM billing_customers WHERE user_id = ?",
(user_id,),
)
async def upsert_subscription(
user_id: int,
plan: str,
status: str,
provider_customer_id: str,
provider_subscription_id: str,
current_period_end: str = None,
) -> int:
"""Create or update subscription."""
"""Create or update subscription, keyed by provider_subscription_id."""
now = datetime.utcnow().isoformat()
customer_col = "paddle_customer_id" # legacy column, kept for existing rows
subscription_col = "provider_subscription_id"
existing = await fetch_one("SELECT id FROM subscriptions WHERE user_id = ?", (user_id,))
existing = await fetch_one(
"SELECT id FROM subscriptions WHERE provider_subscription_id = ?",
(provider_subscription_id,),
)
if existing:
await execute(
f"""UPDATE subscriptions
SET plan = ?, status = ?, {customer_col} = ?, {subscription_col} = ?,
current_period_end = ?, updated_at = ?
WHERE user_id = ?""",
(plan, status, provider_customer_id, provider_subscription_id,
current_period_end, now, user_id),
"""UPDATE subscriptions
SET plan = ?, status = ?, current_period_end = ?, updated_at = ?
WHERE provider_subscription_id = ?""",
(plan, status, current_period_end, now, provider_subscription_id),
)
return existing["id"]
else:
return await execute(
f"""INSERT INTO subscriptions
(user_id, plan, status, {customer_col}, {subscription_col},
current_period_end, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(user_id, plan, status, provider_customer_id, provider_subscription_id,
current_period_end, now, now),
"""INSERT INTO subscriptions
(user_id, plan, status, provider_subscription_id, current_period_end, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(user_id, plan, status, provider_subscription_id, current_period_end, now, now),
)
async def record_transaction(
user_id: int,
provider_transaction_id: str,
type: str = "payment",
amount_cents: int = None,
currency: str = "USD",
status: str = "pending",
) -> int:
"""Record a billing transaction. Idempotent on provider_transaction_id."""
now = datetime.utcnow().isoformat()
return await execute(
"""INSERT OR IGNORE INTO transactions
(user_id, provider_transaction_id, type, amount_cents, currency, status, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(user_id, provider_transaction_id, type, amount_cents, currency, status, now),
)
async def get_subscription_by_provider_id(subscription_id: str) -> dict | None:
return await fetch_one(
@@ -251,11 +281,14 @@ async def webhook():
if event_type == "subscription.activated":
plan = custom_data.get("plan", "starter")
uid = int(user_id) if user_id else 0
customer_id = data.get("customer_id")
if uid and customer_id:
await upsert_billing_customer(uid, str(customer_id))
await upsert_subscription(
user_id=int(user_id) if user_id else 0,
user_id=uid,
plan=plan,
status="active",
provider_customer_id=str(data.get("customer_id", "")),
provider_subscription_id=data.get("id", ""),
current_period_end=data.get("current_billing_period", {}).get("ends_at"),
)