diff --git a/web/src/padelnomics/billing/stripe.py b/web/src/padelnomics/billing/stripe.py new file mode 100644 index 0000000..9ef719e --- /dev/null +++ b/web/src/padelnomics/billing/stripe.py @@ -0,0 +1,329 @@ +""" +Stripe payment provider — checkout sessions, webhook handling, subscription management. + +Exports the same interface as paddle.py so billing/routes.py can dispatch: +- build_checkout_payload() +- build_multi_item_checkout_payload() +- cancel_subscription() +- get_management_url() +- verify_webhook() +- parse_webhook() + +Stripe Tax add-on handles EU VAT collection (must be enabled in Stripe Dashboard). +""" + +import hashlib +import hmac +import json +import logging +import time + +import stripe as stripe_sdk + +from ..core import config + +logger = logging.getLogger(__name__) + +# Timeout for all Stripe API calls (seconds) +_STRIPE_TIMEOUT_SECONDS = 10 + + +def _stripe_client(): + """Configure and return the stripe module with our API key.""" + stripe_sdk.api_key = config.STRIPE_SECRET_KEY + stripe_sdk.max_network_retries = 2 + return stripe_sdk + + +def build_checkout_payload( + price_id: str, custom_data: dict, success_url: str, +) -> dict: + """Create a Stripe Checkout Session for a single item. + + Returns {checkout_url: "https://checkout.stripe.com/..."} — the client + JS redirects the browser there (no overlay SDK needed). + """ + s = _stripe_client() + session = s.checkout.Session.create( + mode=_mode_for_price(s, price_id), + line_items=[{"price": price_id, "quantity": 1}], + metadata=custom_data, + success_url=success_url + "?session_id={CHECKOUT_SESSION_ID}", + cancel_url=success_url.rsplit("/success", 1)[0] + "/pricing", + automatic_tax={"enabled": True}, + tax_id_collection={"enabled": True}, + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + return {"checkout_url": session.url} + + +def build_multi_item_checkout_payload( + items: list[dict], custom_data: dict, success_url: str, +) -> dict: + """Create a Stripe Checkout Session for multiple line items. + + items: list of {"priceId": "price_xxx", "quantity": 1} + """ + s = _stripe_client() + + line_items = [{"price": i["priceId"], "quantity": i.get("quantity", 1)} for i in items] + + # Determine mode: if any item is recurring, use "subscription". + # Otherwise use "payment" for one-time purchases. + has_recurring = any(_is_recurring_price(s, i["priceId"]) for i in items) + mode = "subscription" if has_recurring else "payment" + + session = s.checkout.Session.create( + mode=mode, + line_items=line_items, + metadata=custom_data, + success_url=success_url + "?session_id={CHECKOUT_SESSION_ID}", + cancel_url=success_url.rsplit("/success", 1)[0], + automatic_tax={"enabled": True}, + tax_id_collection={"enabled": True}, + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + return {"checkout_url": session.url} + + +def _mode_for_price(s, price_id: str) -> str: + """Determine Checkout Session mode from price type.""" + try: + price = s.Price.retrieve(price_id, request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}) + return "subscription" if price.type == "recurring" else "payment" + except Exception: + # Default to payment if we can't determine + return "payment" + + +def _is_recurring_price(s, price_id: str) -> bool: + """Check if a Stripe price is recurring (subscription).""" + try: + price = s.Price.retrieve(price_id, request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}) + return price.type == "recurring" + except Exception: + return False + + +def cancel_subscription(provider_subscription_id: str) -> None: + """Cancel a Stripe subscription at end of current billing period.""" + s = _stripe_client() + s.Subscription.modify( + provider_subscription_id, + cancel_at_period_end=True, + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + + +def get_management_url(provider_subscription_id: str) -> str: + """Create a Stripe Billing Portal session and return its URL.""" + s = _stripe_client() + + # Get customer_id from the subscription + sub = s.Subscription.retrieve( + provider_subscription_id, + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + portal = s.billing_portal.Session.create( + customer=sub.customer, + return_url=f"{config.BASE_URL}/billing/success", + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + return portal.url + + +def verify_webhook(payload: bytes, headers) -> bool: + """Verify Stripe webhook signature using the Stripe-Signature header.""" + if not config.STRIPE_WEBHOOK_SECRET: + return True + sig_header = headers.get("Stripe-Signature", "") + if not sig_header: + return False + try: + stripe_sdk.Webhook.construct_event( + payload, sig_header, config.STRIPE_WEBHOOK_SECRET, + ) + return True + except (stripe_sdk.SignatureVerificationError, ValueError): + return False + + +def parse_webhook(payload: bytes) -> dict: + """Parse a Stripe webhook payload into a normalized event dict. + + Maps Stripe event types to the shared format used by _handle_webhook_event(): + - checkout.session.completed (mode=subscription) → subscription.activated + - customer.subscription.updated → subscription.updated + - customer.subscription.deleted → subscription.canceled + - invoice.payment_failed → subscription.past_due + - checkout.session.completed (mode=payment) → transaction.completed + """ + raw = json.loads(payload) + stripe_type = raw.get("type", "") + obj = raw.get("data", {}).get("object", {}) + + # Extract metadata — Stripe stores custom data in session/subscription metadata + metadata = obj.get("metadata") or {} + + # Common fields + customer_id = obj.get("customer", "") + user_id = metadata.get("user_id") + supplier_id = metadata.get("supplier_id") + plan = metadata.get("plan", "") + + # Map Stripe events to our shared event types + if stripe_type == "checkout.session.completed": + mode = obj.get("mode", "") + if mode == "subscription": + subscription_id = obj.get("subscription", "") + # Fetch subscription details for period end + period_end = None + if subscription_id: + try: + s = _stripe_client() + sub = s.Subscription.retrieve( + subscription_id, + request_options={"timeout": _STRIPE_TIMEOUT_SECONDS}, + ) + period_end = _unix_to_iso(sub.current_period_end) + except Exception: + logger.warning("Failed to fetch subscription %s for period_end", subscription_id) + + return { + "event_type": "subscription.activated", + "subscription_id": subscription_id, + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": "active", + "current_period_end": period_end, + "data": obj, + "items": _extract_line_items(obj), + "custom_data": metadata, + } + else: + # One-time payment + return { + "event_type": "transaction.completed", + "subscription_id": "", + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": "completed", + "current_period_end": None, + "data": obj, + "items": _extract_line_items(obj), + "custom_data": metadata, + } + + elif stripe_type == "customer.subscription.updated": + status = _map_stripe_status(obj.get("status", "")) + return { + "event_type": "subscription.updated", + "subscription_id": obj.get("id", ""), + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": status, + "current_period_end": _unix_to_iso(obj.get("current_period_end")), + "data": obj, + "items": _extract_sub_items(obj), + "custom_data": metadata, + } + + elif stripe_type == "customer.subscription.deleted": + return { + "event_type": "subscription.canceled", + "subscription_id": obj.get("id", ""), + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": "cancelled", + "current_period_end": _unix_to_iso(obj.get("current_period_end")), + "data": obj, + "items": _extract_sub_items(obj), + "custom_data": metadata, + } + + elif stripe_type == "invoice.payment_failed": + sub_id = obj.get("subscription", "") + return { + "event_type": "subscription.past_due", + "subscription_id": sub_id, + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": "past_due", + "current_period_end": None, + "data": obj, + "items": [], + "custom_data": metadata, + } + + # Unknown event — return a no-op + return { + "event_type": "", + "subscription_id": "", + "customer_id": str(customer_id), + "user_id": user_id, + "supplier_id": supplier_id, + "plan": plan, + "status": "", + "current_period_end": None, + "data": obj, + "items": [], + "custom_data": metadata, + } + + +# ============================================================================= +# Helpers +# ============================================================================= + +def _map_stripe_status(stripe_status: str) -> str: + """Map Stripe subscription status to our internal status.""" + mapping = { + "active": "active", + "trialing": "on_trial", + "past_due": "past_due", + "canceled": "cancelled", + "unpaid": "past_due", + "incomplete": "past_due", + "incomplete_expired": "expired", + "paused": "paused", + } + return mapping.get(stripe_status, stripe_status) + + +def _unix_to_iso(ts) -> str | None: + """Convert Unix timestamp to ISO string, or None.""" + if not ts: + return None + from datetime import UTC, datetime + return datetime.fromtimestamp(int(ts), tz=UTC).strftime("%Y-%m-%dT%H:%M:%S.000000Z") + + +def _extract_line_items(session_obj: dict) -> list[dict]: + """Extract line items from a Checkout Session in Paddle-compatible format. + + Stripe sessions don't embed line items directly — we'd need an extra API call. + For webhook handling, the key info (price_id) comes from subscription items. + Returns items in the format: [{"price": {"id": "price_xxx"}}] + """ + # For checkout.session.completed, line_items aren't in the webhook payload. + # The webhook handler for subscription.activated fetches them separately. + # For one-time payments, we can reconstruct from the session's line_items + # via the Stripe API, but to keep webhook handling fast we skip this and + # handle it via the subscription events instead. + return [] + + +def _extract_sub_items(sub_obj: dict) -> list[dict]: + """Extract items from a Stripe Subscription object in Paddle-compatible format.""" + items = sub_obj.get("items", {}).get("data", []) + return [{"price": {"id": item.get("price", {}).get("id", "")}} for item in items]