"""
Stripe payment provider — checkout sessions, webhook handling, subscription management.

Exports the same interface as paddle.py so billing/routes.py can dispatch:
    - build_checkout_payload()
    - build_multi_item_checkout_payload()
    - cancel_subscription()
    - get_management_url()
    - verify_webhook()
    - parse_webhook()

Stripe Tax add-on handles EU VAT collection (must be enabled in Stripe Dashboard).
"""

import json
import logging
from datetime import UTC, datetime

import stripe as stripe_sdk

from ..core import config

logger = logging.getLogger(__name__)

# Placeholder that Stripe replaces with the real Checkout Session id on redirect.
_SESSION_ID_PARAM = "session_id={CHECKOUT_SESSION_ID}"

# Stripe list endpoints return at most 100 objects per page; anything lower
# silently truncates sessions with many line items.
_MAX_LINE_ITEMS = 100

# Stripe subscription status -> our internal status. Unknown statuses pass through.
_STRIPE_STATUS_MAP = {
    "active": "active",
    "trialing": "on_trial",
    "past_due": "past_due",
    "canceled": "cancelled",
    "unpaid": "past_due",
    "incomplete": "past_due",
    "incomplete_expired": "expired",
    "paused": "paused",
}


def _stripe_client():
    """Configure and return the stripe module with our API key."""
    stripe_sdk.api_key = config.STRIPE_SECRET_KEY
    stripe_sdk.max_network_retries = 2
    return stripe_sdk


# =============================================================================
# Checkout
# =============================================================================


def _with_session_id(success_url: str) -> str:
    """Append the Checkout Session id placeholder to *success_url*.

    Uses ``&`` when the URL already has a query string so we never emit a
    second ``?`` (which would corrupt both the existing params and ours).
    """
    separator = "&" if "?" in success_url else "?"
    return f"{success_url}{separator}{_SESSION_ID_PARAM}"


def _cancel_base(success_url: str) -> str:
    """Return *success_url* with its last ``/success`` segment (and anything after) removed."""
    return success_url.rsplit("/success", 1)[0]


def build_checkout_payload(
    price_id: str,
    custom_data: dict,
    success_url: str,
) -> dict:
    """Create a Stripe Checkout Session for a single item.

    Args:
        price_id: Stripe Price id to sell, quantity 1. Its type (recurring or
            one-time) selects the session mode.
        custom_data: Stored as session metadata; read back by parse_webhook().
        success_url: Redirect target after checkout. The session id is
            appended as a ``session_id`` query parameter.

    Returns:
        {checkout_url: "https://checkout.stripe.com/..."} — the client JS
        redirects the browser there (no overlay SDK needed).
    """
    s = _stripe_client()
    session = s.checkout.Session.create(
        mode=_mode_for_price(s, price_id),
        line_items=[{"price": price_id, "quantity": 1}],
        metadata=custom_data,
        success_url=_with_session_id(success_url),
        cancel_url=_cancel_base(success_url) + "/pricing",
        automatic_tax={"enabled": True},
        tax_id_collection={"enabled": True},
    )
    return {"checkout_url": session.url}


def build_multi_item_checkout_payload(
    items: list[dict],
    custom_data: dict,
    success_url: str,
) -> dict:
    """Create a Stripe Checkout Session for multiple line items.

    Args:
        items: list of {"priceId": "price_xxx", "quantity": 1}; quantity defaults to 1.
        custom_data: Stored as session metadata; read back by parse_webhook().
        success_url: Redirect target after checkout. The session id is
            appended as a ``session_id`` query parameter.

    Returns:
        {checkout_url: "https://checkout.stripe.com/..."}.
    """
    s = _stripe_client()
    line_items = [{"price": i["priceId"], "quantity": i.get("quantity", 1)} for i in items]

    # If any item is recurring, the whole session must be a subscription;
    # otherwise it is a one-time payment.
    has_recurring = any(_is_recurring_price(s, i["priceId"]) for i in items)
    mode = "subscription" if has_recurring else "payment"

    session = s.checkout.Session.create(
        mode=mode,
        line_items=line_items,
        metadata=custom_data,
        success_url=_with_session_id(success_url),
        # NOTE(review): unlike build_checkout_payload this does not append
        # "/pricing" — preserved as-is; confirm it is intentional.
        cancel_url=_cancel_base(success_url),
        automatic_tax={"enabled": True},
        tax_id_collection={"enabled": True},
    )
    return {"checkout_url": session.url}


def _is_recurring_price(s, price_id: str) -> bool:
    """Return True if the Stripe Price is recurring (subscription).

    Best effort: if the price cannot be retrieved, logs a warning and returns
    False (treated as a one-time price).
    """
    try:
        return s.Price.retrieve(price_id).type == "recurring"
    except Exception:
        logger.warning(
            "Failed to retrieve Stripe price %s; assuming one-time", price_id, exc_info=True
        )
        return False


def _mode_for_price(s, price_id: str) -> str:
    """Return the Checkout Session mode for a price: "subscription" or "payment".

    Defaults to "payment" if the price type cannot be determined.
    """
    return "subscription" if _is_recurring_price(s, price_id) else "payment"


# =============================================================================
# Subscription management
# =============================================================================


def cancel_subscription(provider_subscription_id: str) -> None:
    """Cancel a Stripe subscription at end of current billing period."""
    s = _stripe_client()
    s.Subscription.modify(
        provider_subscription_id,
        cancel_at_period_end=True,
    )


def get_management_url(provider_subscription_id: str) -> str:
    """Create a Stripe Billing Portal session and return its URL.

    The portal is opened for the customer that owns the subscription and
    returns to ``{BASE_URL}/billing/success``.
    """
    s = _stripe_client()
    # The portal is per-customer, so resolve the customer from the subscription.
    sub = s.Subscription.retrieve(provider_subscription_id)
    portal = s.billing_portal.Session.create(
        customer=sub.customer,
        return_url=f"{config.BASE_URL}/billing/success",
    )
    return portal.url


# =============================================================================
# Webhooks
# =============================================================================


def verify_webhook(payload: bytes, headers) -> bool:
    """Verify Stripe webhook signature using the Stripe-Signature header.

    Returns True when the signature is valid. It also returns True when
    STRIPE_WEBHOOK_SECRET is not configured, because there is nothing to verify
    against (dev convenience). A warning is logged in that case.
    """
    if not config.STRIPE_WEBHOOK_SECRET:
        logger.warning("STRIPE_WEBHOOK_SECRET is not set; accepting unsigned Stripe webhook")
        return True

    sig_header = headers.get("Stripe-Signature", "")
    if not sig_header:
        return False

    try:
        stripe_sdk.Webhook.construct_event(
            payload,
            sig_header,
            config.STRIPE_WEBHOOK_SECRET,
        )
        return True
    except (stripe_sdk.SignatureVerificationError, ValueError):
        return False


def parse_webhook(payload: bytes) -> dict:
    """Parse a Stripe webhook payload into a normalized event dict.

    Maps Stripe event types to the shared format used by _handle_webhook_event():
        - checkout.session.completed (mode=subscription) → subscription.activated
        - customer.subscription.created                  → subscription.activated
        - customer.subscription.updated                  → subscription.updated
        - customer.subscription.deleted                  → subscription.canceled
        - invoice.payment_failed                         → subscription.past_due
        - checkout.session.completed (mode=payment)      → transaction.completed

    Anything else, including setup-mode checkout sessions, which collect no
    payment, yields an event with ``event_type == ""`` (a no-op).

    Every returned dict has the keys: event_type, subscription_id, customer_id,
    user_id, supplier_id, plan, status, current_period_end, data, items,
    custom_data.
    """
    raw = json.loads(payload)
    stripe_type = raw.get("type", "")
    obj = (raw.get("data") or {}).get("object") or {}

    # Stripe stores our custom data in session/subscription/invoice metadata.
    metadata = obj.get("metadata") or {}
    # `or ""` because Stripe sends `"customer": null` (e.g. guest checkouts);
    # str(None) would otherwise produce the bogus id "None".
    customer_id = str(obj.get("customer") or "")

    def event(
        event_type: str,
        subscription_id: str | None = "",
        status: str = "",
        current_period_end: str | None = None,
        items: list[dict] | None = None,
    ) -> dict:
        """Build the normalized event dict shared by every branch."""
        return {
            "event_type": event_type,
            "subscription_id": subscription_id or "",
            "customer_id": customer_id,
            "user_id": metadata.get("user_id"),
            "supplier_id": metadata.get("supplier_id"),
            "plan": metadata.get("plan", ""),
            "status": status,
            "current_period_end": current_period_end,
            "data": obj,
            "items": items if items is not None else [],
            "custom_data": metadata,
        }

    if stripe_type == "checkout.session.completed":
        mode = obj.get("mode", "")
        if mode == "subscription":
            subscription_id = obj.get("subscription") or ""
            return event(
                "subscription.activated",
                subscription_id,
                "active",
                _fetch_subscription_period_end(subscription_id),
                _extract_line_items(obj),
            )
        if mode != "setup":
            # One-time payment.
            # NOTE(review): for delayed-notification payment methods Stripe can
            # send this event before the payment settles (payment_status != "paid");
            # confirm whether async_payment_succeeded should be handled instead.
            return event("transaction.completed", "", "completed", None, _extract_line_items(obj))
        # Setup-mode sessions take no payment: fall through to the no-op event.

    elif stripe_type in ("customer.subscription.created", "customer.subscription.updated"):
        # `created` maps to `activated` so the handler creates the DB row.
        event_type = (
            "subscription.activated"
            if stripe_type == "customer.subscription.created"
            else "subscription.updated"
        )
        return event(
            event_type,
            obj.get("id", ""),
            _map_stripe_status(obj.get("status", "")),
            _get_period_end(obj),
            _extract_sub_items(obj),
        )

    elif stripe_type == "customer.subscription.deleted":
        return event(
            "subscription.canceled",
            obj.get("id", ""),
            "cancelled",
            _get_period_end(obj),
            _extract_sub_items(obj),
        )

    elif stripe_type == "invoice.payment_failed":
        # NOTE(review): user_id/plan here come from the invoice's own metadata,
        # which may be empty; the handler presumably resolves by subscription id.
        return event("subscription.past_due", _invoice_subscription_id(obj), "past_due", None, [])

    # Unknown (or deliberately ignored) event — return a no-op.
    return event("")


# =============================================================================
# Helpers
# =============================================================================


def _map_stripe_status(stripe_status: str) -> str:
    """Map Stripe subscription status to our internal status (unknown values pass through)."""
    return _STRIPE_STATUS_MAP.get(stripe_status, stripe_status)


def _unix_to_iso(ts) -> str | None:
    """Convert a Unix timestamp to a UTC ISO-8601 string, or None if falsy."""
    if not ts:
        return None
    return datetime.fromtimestamp(int(ts), tz=UTC).strftime("%Y-%m-%dT%H:%M:%S.000000Z")


def _get_period_end(obj: dict) -> str | None:
    """Extract current_period_end from a subscription or its first item.

    Newer Stripe API versions moved the period fields from the subscription to
    the subscription items, so both locations are checked. Uses key-based
    access so it works on plain dicts and on SDK StripeObjects alike.
    """
    ts = obj.get("current_period_end")
    if not ts:
        items = (obj.get("items") or {}).get("data") or []
        if items:
            ts = items[0].get("current_period_end")
    return _unix_to_iso(ts)


def _fetch_subscription_period_end(subscription_id: str) -> str | None:
    """Retrieve a subscription from Stripe and return its period end as ISO, or None.

    Best effort: API failures are logged and yield None.
    """
    if not subscription_id:
        return None
    try:
        sub = _stripe_client().Subscription.retrieve(subscription_id)
        # Key-based lookup: attribute access (sub.current_period_end) raises
        # AttributeError when the field has moved to the items, which would
        # skip the item-level fallback entirely.
        return _get_period_end(sub)
    except Exception:
        logger.warning("Failed to fetch subscription %s for period_end", subscription_id)
        return None


def _invoice_subscription_id(invoice: dict) -> str:
    """Return the id of the subscription an invoice belongs to, or "".

    Older API versions expose it as ``invoice.subscription``. Newer versions
    nest it under ``invoice.parent.subscription_details.subscription``.
    """
    sub_id = invoice.get("subscription")
    if not sub_id:
        parent = invoice.get("parent") or {}
        sub_id = (parent.get("subscription_details") or {}).get("subscription")
    return sub_id or ""


def _price_refs(raw_items) -> list[dict]:
    """Convert Stripe line items to [{"price": {"id": ...}}], skipping items without a price id."""
    return [
        {"price": {"id": price_id}}
        for item in raw_items
        if (price_id := (item.get("price") or {}).get("id"))
    ]


def _extract_line_items(session_obj: dict) -> list[dict]:
    """Extract line items from a Checkout Session in Paddle-compatible format.

    Stripe doesn't embed line_items in checkout.session.completed webhooks,
    so we fetch them via the API. Returns [{"price": {"id": "price_xxx"}}].
    If the API call fails, falls back to line_items embedded in the payload
    (e.g. in tests).
    """
    session_id = session_obj.get("id") or ""
    if not session_id.startswith("cs_"):
        return []

    try:
        s = _stripe_client()
        line_items = s.checkout.Session.list_line_items(session_id, limit=_MAX_LINE_ITEMS)
        return _price_refs(line_items.get("data", []))
    except Exception:
        logger.warning("Failed to fetch line_items for session %s", session_id)

    # Fallback: check if line_items were embedded in the payload (e.g. tests)
    embedded = (session_obj.get("line_items") or {}).get("data") or []
    return _price_refs(embedded)


def _extract_sub_items(sub_obj: dict) -> list[dict]:
    """Extract items from a Stripe Subscription object in Paddle-compatible format.

    Unlike _extract_line_items, an item with no price id is kept, with an
    empty-string id.
    """
    items = (sub_obj.get("items") or {}).get("data") or []
    return [{"price": {"id": (item.get("price") or {}).get("id", "")}} for item in items]