diff --git a/web/src/padelnomics/scripts/setup_stripe.py b/web/src/padelnomics/scripts/setup_stripe.py new file mode 100644 index 0000000..749d95e --- /dev/null +++ b/web/src/padelnomics/scripts/setup_stripe.py @@ -0,0 +1,247 @@ +""" +Create or sync Stripe products, prices, and webhook endpoint. + +Prerequisites: + - Enable Stripe Tax in your Stripe Dashboard (Settings → Tax) + - Set STRIPE_SECRET_KEY in .env + +Commands: + uv run python -m padelnomics.scripts.setup_stripe # create products + webhook + uv run python -m padelnomics.scripts.setup_stripe --sync # re-populate DB from existing Stripe products +""" + +import logging +import os +import re +import sqlite3 +import sys +from pathlib import Path + +import stripe +from dotenv import load_dotenv + +logger = logging.getLogger(__name__) + +load_dotenv() + +STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY", "") +DATABASE_PATH = os.getenv("DATABASE_PATH", "data/app.db") +BASE_URL = os.getenv("BASE_URL", "http://localhost:5000") + +if not STRIPE_SECRET_KEY: + logging.basicConfig(level=logging.INFO, format="%(levelname)-8s %(message)s") + logger.error("Set STRIPE_SECRET_KEY in .env first") + sys.exit(1) + +stripe.api_key = STRIPE_SECRET_KEY +stripe.max_network_retries = 2 + +# Product definitions — same keys as setup_paddle.py. +# Prices in EUR cents, matching Paddle exactly. +PRODUCTS = [ + # Supplier Growth + { + "key": "supplier_growth", + "name": "Supplier Growth", + "price": 19900, + "currency": "eur", + "interval": "month", + "billing_type": "subscription", + }, + { + "key": "supplier_growth_yearly", + "name": "Supplier Growth (Yearly)", + "price": 179900, + "currency": "eur", + "interval": "year", + "billing_type": "subscription", + }, + # Supplier Pro + { + "key": "supplier_pro", + "name": "Supplier Pro", + "price": 49900, + "currency": "eur", + "interval": "month", + "billing_type": "subscription", + }, + { + "key": "supplier_pro_yearly", + "name": "Supplier Pro (Yearly)", + "price": 449900, + "currency": "eur", + "interval": "year", + "billing_type": "subscription", + }, + # Boost add-ons (subscriptions) + {"key": "boost_logo", "name": "Boost: Logo", "price": 2900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, + {"key": "boost_highlight", "name": "Boost: Highlight", "price": 3900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, + {"key": "boost_verified", "name": "Boost: Verified Badge", "price": 4900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, + {"key": "boost_card_color", "name": "Boost: Custom Card Color", "price": 5900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, + # One-time boosts + {"key": "boost_sticky_week", "name": "Boost: Sticky Top 1 Week", "price": 7900, "currency": "eur", "billing_type": "one_time"}, + {"key": "boost_sticky_month", "name": "Boost: Sticky Top 1 Month", "price": 19900, "currency": "eur", "billing_type": "one_time"}, + # Credit packs + {"key": "credits_25", "name": "Credit Pack 25", "price": 9900, "currency": "eur", "billing_type": "one_time"}, + {"key": "credits_50", "name": "Credit Pack 50", "price": 17900, "currency": "eur", "billing_type": "one_time"}, + {"key": "credits_100", "name": "Credit Pack 100", "price": 32900, "currency": "eur", "billing_type": "one_time"}, + {"key": "credits_250", "name": "Credit Pack 250", "price": 74900, "currency": "eur", "billing_type": "one_time"}, + # PDF product + {"key": "business_plan", "name": "Padel Business Plan (PDF)", "price": 14900, "currency": "eur", "billing_type": "one_time"}, + # Planner subscriptions + {"key": "starter", "name": "Planner Starter", "price": 1900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, + {"key": "pro", "name": "Planner Pro", "price": 4900, "currency": "eur", "interval": "month", "billing_type": "subscription"}, +] + +_PRODUCT_BY_NAME = {p["name"]: p for p in PRODUCTS} + + +def _open_db(): + db_path = DATABASE_PATH + if not Path(db_path).exists(): + logger.error("Database not found at %s. Run migrations first.", db_path) + sys.exit(1) + conn = sqlite3.connect(db_path) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + +def _write_product(conn, key, product_id, price_id, name, price_cents, billing_type): + conn.execute( + """INSERT OR REPLACE INTO payment_products + (provider, key, provider_product_id, provider_price_id, name, price_cents, currency, billing_type) + VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", + ("stripe", key, product_id, price_id, name, price_cents, "EUR", billing_type), + ) + + +def sync(conn): + """Fetch existing Stripe products and re-populate payment_products table.""" + logger.info("Syncing products from Stripe...") + + # Fetch all products (auto-paginated, max 100 per page) + products = stripe.Product.list(limit=100, active=True) + matched = 0 + + for product in products.auto_paging_iter(): + spec = _PRODUCT_BY_NAME.get(product.name) + if not spec: + continue + + # Get the first active price for this product + prices = stripe.Price.list(product=product.id, active=True, limit=1) + if not prices.data: + logger.warning(" SKIP %s: no active prices on %s", spec["key"], product.id) + continue + + price = prices.data[0] + _write_product( + conn, spec["key"], product.id, price.id, + spec["name"], spec["price"], spec["billing_type"], + ) + matched += 1 + logger.info(" %s: %s / %s", spec["key"], product.id, price.id) + + conn.commit() + + if matched == 0: + logger.warning("No matching products found in Stripe. Run without --sync first.") + else: + logger.info("%s/%s products synced to DB", matched, len(PRODUCTS)) + + +def create(conn): + """Create new products and prices in Stripe, write to DB, set up webhook.""" + logger.info("Creating products in Stripe...") + + for spec in PRODUCTS: + product = stripe.Product.create( + name=spec["name"], + tax_code="txcd_10000000", # General — Tangible Goods (Stripe default) + ) + logger.info(" Product: %s -> %s", spec["name"], product.id) + + price_params = { + "product": product.id, + "unit_amount": spec["price"], + "currency": spec["currency"], + "tax_behavior": "exclusive", # Price + tax on top (EU standard) + } + + if spec["billing_type"] == "subscription": + interval = spec.get("interval", "month") + price_params["recurring"] = {"interval": interval} + + price = stripe.Price.create(**price_params) + logger.info(" Price: %s = %s", spec["key"], price.id) + + _write_product( + conn, spec["key"], product.id, price.id, + spec["name"], spec["price"], spec["billing_type"], + ) + + conn.commit() + logger.info("All products written to DB") + + # -- Webhook endpoint ------------------------------------------------------- + + webhook_url = f"{BASE_URL}/billing/webhook/stripe" + enabled_events = [ + "checkout.session.completed", + "customer.subscription.updated", + "customer.subscription.deleted", + "invoice.payment_failed", + ] + + logger.info("Creating webhook endpoint...") + logger.info(" URL: %s", webhook_url) + + endpoint = stripe.WebhookEndpoint.create( + url=webhook_url, + enabled_events=enabled_events, + ) + + webhook_secret = endpoint.secret + logger.info(" ID: %s", endpoint.id) + logger.info(" Secret: %s", webhook_secret) + + env_path = Path(".env") + env_vars = { + "STRIPE_WEBHOOK_SECRET": webhook_secret, + "STRIPE_WEBHOOK_ENDPOINT_ID": endpoint.id, + } + if env_path.exists(): + env_text = env_path.read_text() + for key, value in env_vars.items(): + pattern = rf"^{key}=.*$" + replacement = f"{key}={value}" + if re.search(pattern, env_text, flags=re.MULTILINE): + env_text = re.sub(pattern, replacement, env_text, flags=re.MULTILINE) + else: + env_text = env_text.rstrip("\n") + f"\n{replacement}\n" + env_path.write_text(env_text) + logger.info("STRIPE_WEBHOOK_SECRET and STRIPE_WEBHOOK_ENDPOINT_ID written to .env") + else: + logger.info("Add to .env:") + for key, value in env_vars.items(): + logger.info(" %s=%s", key, value) + + logger.info("Done. Remember to enable Stripe Tax in your Dashboard (Settings > Tax).") + + +def main(): + conn = _open_db() + + try: + if "--sync" in sys.argv: + sync(conn) + else: + create(conn) + finally: + conn.close() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)-8s %(message)s") + main()