diff --git a/web/.copier-answers.yml b/web/.copier-answers.yml index 96f3567..840dfc8 100644 --- a/web/.copier-answers.yml +++ b/web/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier; NEVER EDIT MANUALLY -_commit: v0.3.0 +_commit: v0.4.0 _src_path: git@gitlab.com:deemanone/materia_saas_boilerplate.master.git author_email: hendrik@beanflows.coffee author_name: Hendrik Deeman diff --git a/web/CHANGELOG.md b/web/CHANGELOG.md new file mode 100644 index 0000000..702897c --- /dev/null +++ b/web/CHANGELOG.md @@ -0,0 +1,73 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/). + +## [Unreleased] + +### Changed +- **Role-based access control**: `user_roles` table with `role_required()` decorator replaces password-based admin auth +- **Admin is a real user**: admins authenticate via magic links; `ADMIN_EMAILS` env var auto-grants admin role on login +- **Separated billing entities**: `billing_customers` table holds payment provider identity; `subscriptions` table holds only subscription state +- **Multiple subscriptions per user**: dropped UNIQUE constraint on `subscriptions.user_id`; `upsert_subscription` finds by `provider_subscription_id` + +### Added +- Simple A/B testing with `@ab_test` decorator and optional Umami `data-tag` integration (`UMAMI_SCRIPT_URL` / `UMAMI_WEBSITE_ID` env vars) +- `user_roles` table and `grant_role()` / `revoke_role()` / `ensure_admin_role()` functions +- `billing_customers` table and `upsert_billing_customer()` / `get_billing_customer()` functions +- `role_required(*roles)` decorator in auth +- `is_admin` template context variable +- Migration `0001_roles_and_billing_customers.py` for existing databases + +### Removed +- `ADMIN_PASSWORD` env var and password-based admin login +- `provider_customer_id` column from `subscriptions` table +- `admin/templates/admin/login.html` + +### Changed +- 
**Provider-agnostic schema**: generic `provider_customer_id` / `provider_subscription_id` columns replace provider-prefixed names (`stripe_customer_id`, `paddle_customer_id`, `lemonsqueezy_customer_id`) — eliminates all Jinja conditionals from schema, SQL helpers, and route code +- **Consolidated `subscription_required` decorator**: single implementation in `auth/routes.py` supporting both plan and status checks, reads from eager-loaded `g.subscription` (zero extra queries) +- **Eager-loaded `g.subscription`**: `load_user` in `app.py` now fetches user + subscription in a single JOIN; available in all routes and templates via `g.subscription` + +### Added +- `transactions` table for recording payment/refund events with idempotent `record_transaction()` helper +- Billing event hook system: `on_billing_event()` decorator and `_fire_hooks()` for domain code to react to subscription changes; errors are logged and never cause webhook 500s + +### Removed +- Duplicate `subscription_required` decorator from `billing/routes.py` (consolidated in `auth/routes.py`) +- `get_user_with_subscription()` from `auth/routes.py` (replaced by eager-loaded `g.subscription`) + +### Changed +- **Email SDK migration**: replaced raw httpx calls with official `resend` SDK in `core.py` + - Added `from_addr` parameter to `send_email()` for multi-address support + - Added `EMAIL_ADDRESSES` dict for named sender addresses (transactional, etc.) 
+- **Paddle SDK migration**: replaced raw httpx calls with official `paddle-python-sdk` in `billing/routes.py` + - Checkout, manage, cancel routes now use typed SDK methods (`PaddleClient`, `CreateTransaction`) + - Webhook verification uses SDK's `Verifier` instead of hand-rolled HMAC + - Added `PADDLE_ENVIRONMENT` config for sandbox/production toggling + - Added `_paddle_client()` helper factory +- **Dependencies**: `resend` replaces `httpx` for email; `paddle-python-sdk` replaces `httpx` for Paddle billing; `httpx` now only included for LemonSqueezy projects +- Worker `send_email` task handler now passes through `from_addr` + +### Added +- `scripts/setup_paddle.py` — CLI script to create Paddle products/prices programmatically (Paddle projects only) + +### Changed +- **Pico CSS → Tailwind CSS v4** — full design system migration across all templates + - Standalone Tailwind CLI binary (no Node.js) with `make css-build` / `make css-watch` + - Brand theme with component classes (`.btn`, `.card`, `.form-input`, `.table`, `.badge`, `.flash`, etc.) 
+ - Self-hosted Commit Mono font for monospace data display + - Docker multi-stage build: CSS compiled in dedicated stage before Python build + +### Removed +- Pico CSS CDN dependency +- `custom.css` (replaced by Tailwind `input.css` with `@layer components`) +- JetBrains Mono font (replaced by self-hosted Commit Mono) + +### Fixed +- Admin template collision: namespaced admin templates under `admin/` subdirectory to prevent Quart's template loader from resolving auth's `login.html` or dashboard's `index.html` instead of admin's +- Admin user detail: `stripe_customer_id` hardcoded regardless of payment provider — now uses provider-aware Copier conditional (Stripe/Paddle/LemonSqueezy) + +### Added +- Initial project scaffolded from quart_saas_boilerplate diff --git a/web/Dockerfile b/web/Dockerfile index 726d8f6..953875d 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,3 +1,13 @@ +# CSS build stage (Tailwind standalone CLI, no Node.js) +FROM debian:bookworm-slim AS css-build +ADD https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 /usr/local/bin/tailwindcss +RUN chmod +x /usr/local/bin/tailwindcss +WORKDIR /app +COPY src/ ./src/ +RUN tailwindcss -i ./src/beanflows/static/css/input.css \ + -o ./src/beanflows/static/css/output.css --minify + + # Build stage FROM python:3.12-slim AS build COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/ @@ -15,6 +25,7 @@ RUN useradd -m -u 1000 appuser WORKDIR /app RUN mkdir -p /app/data && chown -R appuser:appuser /app COPY --from=build --chown=appuser:appuser /app . 
+COPY --from=css-build /app/src/beanflows/static/css/output.css ./src/beanflows/static/css/output.css USER appuser ENV PYTHONUNBUFFERED=1 ENV DATABASE_PATH=/app/data/app.db diff --git a/web/docker-compose.prod.yml b/web/docker-compose.prod.yml index d572da9..d2a06df 100644 --- a/web/docker-compose.prod.yml +++ b/web/docker-compose.prod.yml @@ -21,16 +21,16 @@ services: command: replicate -config /etc/litestream.yml volumes: - app-data:/app/data - - ./beanflows/litestream.yml:/etc/litestream.yml:ro + - ./litestream.yml:/etc/litestream.yml:ro # ── Blue slot ───────────────────────────────────────────── blue-app: profiles: ["blue"] build: - context: ./beanflows + context: . restart: unless-stopped - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: @@ -47,10 +47,10 @@ services: blue-worker: profiles: ["blue"] build: - context: ./beanflows + context: . restart: unless-stopped command: python -m beanflows.worker - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: @@ -61,10 +61,10 @@ services: blue-scheduler: profiles: ["blue"] build: - context: ./beanflows + context: . restart: unless-stopped command: python -m beanflows.worker scheduler - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: @@ -77,9 +77,9 @@ services: green-app: profiles: ["green"] build: - context: ./beanflows + context: . restart: unless-stopped - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: @@ -96,10 +96,10 @@ services: green-worker: profiles: ["green"] build: - context: ./beanflows + context: . restart: unless-stopped command: python -m beanflows.worker - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: @@ -110,10 +110,10 @@ services: green-scheduler: profiles: ["green"] build: - context: ./beanflows + context: . 
restart: unless-stopped command: python -m beanflows.worker scheduler - env_file: ./beanflows/.env + env_file: ./.env environment: - DATABASE_PATH=/app/data/app.db volumes: diff --git a/web/pyproject.toml b/web/pyproject.toml index b5d3eea..3f7a002 100644 --- a/web/pyproject.toml +++ b/web/pyproject.toml @@ -12,8 +12,9 @@ dependencies = [ "aiosqlite>=0.19.0", "duckdb>=1.0.0", "httpx>=0.27.0", + "resend>=2.22.0", "python-dotenv>=1.0.0", - + "paddle-python-sdk>=1.13.0", "itsdangerous>=2.1.0", "jinja2>=3.1.0", "hypercorn>=0.17.0", diff --git a/web/src/beanflows/app.py b/web/src/beanflows/app.py index 1fc0648..b364725 100644 --- a/web/src/beanflows/app.py +++ b/web/src/beanflows/app.py @@ -12,24 +12,24 @@ from .core import close_db, config, get_csrf_token, init_db, setup_request_id def create_app() -> Quart: """Create and configure the Quart application.""" - + # Get package directory for templates pkg_dir = Path(__file__).parent - + app = Quart( __name__, template_folder=str(pkg_dir / "templates"), static_folder=str(pkg_dir / "static"), ) - + app.secret_key = config.SECRET_KEY - + # Session config app.config["SESSION_COOKIE_SECURE"] = not config.DEBUG app.config["SESSION_COOKIE_HTTPONLY"] = True app.config["SESSION_COOKIE_SAMESITE"] = "Lax" app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * config.SESSION_LIFETIME_DAYS - + # Database lifecycle @app.before_serving async def startup(): @@ -41,7 +41,7 @@ def create_app() -> Quart: async def shutdown(): close_analytics_db() await close_db() - + # Security headers @app.after_request async def add_security_headers(response): @@ -51,16 +51,42 @@ def create_app() -> Quart: if not config.DEBUG: response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" return response - - # Load current user before each request + + # Load current user + subscription + roles before each request @app.before_request async def load_user(): g.user = None + g.subscription = None user_id = session.get("user_id") if 
user_id: - from .auth.routes import get_user_by_id - g.user = await get_user_by_id(user_id) - + from .core import fetch_one as _fetch_one + row = await _fetch_one( + """SELECT u.*, + bc.provider_customer_id, + (SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv, + s.id AS sub_id, s.plan, s.status AS sub_status, + s.provider_subscription_id, s.current_period_end + FROM users u + LEFT JOIN billing_customers bc ON bc.user_id = u.id + LEFT JOIN subscriptions s ON s.id = ( + SELECT id FROM subscriptions + WHERE user_id = u.id + ORDER BY created_at DESC LIMIT 1 + ) + WHERE u.id = ? AND u.deleted_at IS NULL""", + (user_id,), + ) + if row: + g.user = dict(row) + g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else [] + if row["sub_id"]: + g.subscription = { + "id": row["sub_id"], "plan": row["plan"], + "status": row["sub_status"], + "provider_subscription_id": row["provider_subscription_id"], + "current_period_end": row["current_period_end"], + } + # Template context globals @app.context_processor def inject_globals(): @@ -68,10 +94,14 @@ def create_app() -> Quart: return { "config": config, "user": g.get("user"), + "subscription": g.get("subscription"), + "is_admin": "admin" in (g.get("user") or {}).get("roles", []), "now": datetime.utcnow(), "csrf_token": get_csrf_token, + "ab_variant": getattr(g, "ab_variant", None), + "ab_tag": getattr(g, "ab_tag", None), } - + # Health check @app.route("/health") async def health(): @@ -94,7 +124,7 @@ def create_app() -> Quart: result["duckdb"] = "not configured" status_code = 200 if result["status"] == "healthy" else 500 return result, status_code - + # Register blueprints from .admin.routes import bp as admin_bp from .api.routes import bp as api_bp @@ -102,17 +132,17 @@ def create_app() -> Quart: from .billing.routes import bp as billing_bp from .dashboard.routes import bp as dashboard_bp from .public.routes import bp as public_bp - + app.register_blueprint(public_bp) 
app.register_blueprint(auth_bp) app.register_blueprint(dashboard_bp) app.register_blueprint(billing_bp) app.register_blueprint(api_bp, url_prefix="/api/v1") app.register_blueprint(admin_bp) - + # Request ID tracking setup_request_id(app) - + return app diff --git a/web/src/beanflows/auth/routes.py b/web/src/beanflows/auth/routes.py index 8a6af9a..dc7fc98 100644 --- a/web/src/beanflows/auth/routes.py +++ b/web/src/beanflows/auth/routes.py @@ -71,7 +71,7 @@ async def get_valid_token(token: str) -> dict | None: """Get token if valid and not expired.""" return await fetch_one( """ - SELECT at.*, u.email + SELECT at.*, u.email FROM auth_tokens at JOIN users u ON u.id = at.user_id WHERE at.token = ? AND at.expires_at > ? AND at.used_at IS NULL @@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None: ) -async def get_user_with_subscription(user_id: int) -> dict | None: - """Get user with their active subscription info.""" - return await fetch_one( - """ - SELECT u.*, s.plan, s.status as sub_status, s.current_period_end - FROM users u - LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active' - WHERE u.id = ? 
AND u.deleted_at IS NULL - """, - (user_id,) - ) - - # ============================================================================= # Decorators # ============================================================================= @@ -116,24 +103,69 @@ def login_required(f): return decorated -def subscription_required(plans: list[str] = None): - """Require active subscription, optionally of specific plan(s).""" +def role_required(*roles): + """Require user to have at least one of the given roles.""" + def decorator(f): + @wraps(f) + async def decorated(*args, **kwargs): + if not g.get("user"): + await flash("Please sign in to continue.", "warning") + return redirect(url_for("auth.login", next=request.path)) + user_roles = g.user.get("roles", []) + if not any(r in user_roles for r in roles): + await flash("You don't have permission to access that page.", "error") + return redirect(url_for("dashboard.index")) + return await f(*args, **kwargs) + return decorated + return decorator + + +async def grant_role(user_id: int, role: str) -> None: + """Grant a role to a user (idempotent).""" + await execute( + "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)", + (user_id, role), + ) + + +async def revoke_role(user_id: int, role: str) -> None: + """Revoke a role from a user.""" + await execute( + "DELETE FROM user_roles WHERE user_id = ? AND role = ?", + (user_id, role), + ) + + +async def ensure_admin_role(user_id: int, email: str) -> None: + """Grant admin role if email is in ADMIN_EMAILS.""" + if email.lower() in config.ADMIN_EMAILS: + await grant_role(user_id, "admin") + + +def subscription_required( + plans: list[str] = None, + allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"), +): + """Require active subscription, optionally of specific plan(s) and/or statuses. + + Reads from g.subscription (eager-loaded in load_user) — zero extra queries. 
+ """ def decorator(f): @wraps(f) async def decorated(*args, **kwargs): if not g.get("user"): await flash("Please sign in to continue.", "warning") return redirect(url_for("auth.login")) - - user = await get_user_with_subscription(g.user["id"]) - if not user or not user.get("plan"): + + sub = g.get("subscription") + if not sub or sub["status"] not in allowed: await flash("Please subscribe to access this feature.", "warning") return redirect(url_for("billing.pricing")) - - if plans and user["plan"] not in plans: + + if plans and sub["plan"] not in plans: await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning") return redirect(url_for("billing.pricing")) - + return await f(*args, **kwargs) return decorated return decorator @@ -149,33 +181,33 @@ async def login(): """Login page - request magic link.""" if g.get("user"): return redirect(url_for("dashboard.index")) - + if request.method == "POST": form = await request.form email = form.get("email", "").strip().lower() - + if not email or "@" not in email: await flash("Please enter a valid email address.", "error") return redirect(url_for("auth.login")) - + # Get or create user user = await get_user_by_email(email) if not user: user_id = await create_user(email) else: user_id = user["id"] - + # Create magic link token token = secrets.token_urlsafe(32) await create_auth_token(user_id, token) - + # Queue email from ..worker import enqueue await enqueue("send_magic_link", {"email": email, "token": token}) - + await flash("Check your email for the sign-in link!", "success") return redirect(url_for("auth.magic_link_sent", email=email)) - + return await render_template("login.html") @@ -185,39 +217,39 @@ async def signup(): """Signup page - same as login but with different messaging.""" if g.get("user"): return redirect(url_for("dashboard.index")) - + plan = request.args.get("plan", "free") - + if request.method == "POST": form = await request.form email = form.get("email", "").strip().lower() 
selected_plan = form.get("plan", "free") - + if not email or "@" not in email: await flash("Please enter a valid email address.", "error") return redirect(url_for("auth.signup", plan=selected_plan)) - + # Check if user exists user = await get_user_by_email(email) if user: await flash("Account already exists. Please sign in.", "info") return redirect(url_for("auth.login")) - + # Create user user_id = await create_user(email) - + # Create magic link token token = secrets.token_urlsafe(32) await create_auth_token(user_id, token) - + # Queue emails from ..worker import enqueue await enqueue("send_magic_link", {"email": email, "token": token}) await enqueue("send_welcome", {"email": email}) - + await flash("Check your email to complete signup!", "success") return redirect(url_for("auth.magic_link_sent", email=email)) - + return await render_template("signup.html", plan=plan) @@ -225,29 +257,32 @@ async def signup(): async def verify(): """Verify magic link token.""" token = request.args.get("token") - + if not token: await flash("Invalid or expired link.", "error") return redirect(url_for("auth.login")) - + token_data = await get_valid_token(token) - + if not token_data: await flash("Invalid or expired link. Please request a new one.", "error") return redirect(url_for("auth.login")) - + # Mark token as used await mark_token_used(token_data["id"]) - + # Update last login await update_user(token_data["user_id"], last_login_at=datetime.utcnow().isoformat()) - + # Set session session.permanent = True session["user_id"] = token_data["user_id"] - + + # Auto-grant admin role if email is in ADMIN_EMAILS + await ensure_admin_role(token_data["user_id"], token_data["email"]) + await flash("Successfully signed in!", "success") - + # Redirect to intended page or dashboard next_url = request.args.get("next", url_for("dashboard.index")) return redirect(next_url) @@ -274,18 +309,21 @@ async def dev_login(): """Instant login for development. 
Only works in DEBUG mode.""" if not config.DEBUG: return "Not available", 404 - + email = request.args.get("email", "dev@localhost") - + user = await get_user_by_email(email) if not user: user_id = await create_user(email) else: user_id = user["id"] - + session.permanent = True session["user_id"] = user_id - + + # Auto-grant admin role if email is in ADMIN_EMAILS + await ensure_admin_role(user_id, email) + await flash(f"Dev login as {email}", "success") return redirect(url_for("dashboard.index")) @@ -296,19 +334,19 @@ async def resend(): """Resend magic link.""" form = await request.form email = form.get("email", "").strip().lower() - + if not email: await flash("Email address required.", "error") return redirect(url_for("auth.login")) - + user = await get_user_by_email(email) if user: token = secrets.token_urlsafe(32) await create_auth_token(user["id"], token) - + from ..worker import enqueue await enqueue("send_magic_link", {"email": email, "token": token}) - + # Always show success (don't reveal if email exists) await flash("If that email is registered, we've sent a new link.", "success") return redirect(url_for("auth.magic_link_sent", email=email)) diff --git a/web/src/beanflows/core.py b/web/src/beanflows/core.py index 2479d73..62e9d1f 100644 --- a/web/src/beanflows/core.py +++ b/web/src/beanflows/core.py @@ -2,16 +2,19 @@ Core infrastructure: database, config, email, and shared utilities. 
""" import os +import random import secrets import hashlib import hmac + import aiosqlite +import resend import httpx from pathlib import Path from functools import wraps from datetime import datetime, timedelta from contextvars import ContextVar -from quart import request, session, g +from quart import g, make_response, request, session from dotenv import load_dotenv # web/.env is three levels up from web/src/beanflows/core.py @@ -26,27 +29,35 @@ class Config: SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production") BASE_URL: str = os.getenv("BASE_URL", "http://localhost:5001") DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" - + DATABASE_PATH: str = os.getenv("DATABASE_PATH", "data/app.db") - + MAGIC_LINK_EXPIRY_MINUTES: int = int(os.getenv("MAGIC_LINK_EXPIRY_MINUTES", "15")) SESSION_LIFETIME_DAYS: int = int(os.getenv("SESSION_LIFETIME_DAYS", "30")) - + PAYMENT_PROVIDER: str = "paddle" - + PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "") PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "") + PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox") PADDLE_PRICES: dict = { "starter": os.getenv("PADDLE_PRICE_STARTER", ""), "pro": os.getenv("PADDLE_PRICE_PRO", ""), } - + + UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "") + UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "") + RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "") EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com") - + + ADMIN_EMAILS: list[str] = [ + e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip() + ] + RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100")) RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60")) - + PLAN_FEATURES: dict = { "free": ["dashboard", "coffee_only", "limited_history"], "starter": ["dashboard", "coffee_only", "full_history", "export", "api"], @@ -74,10 +85,10 @@ async def init_db(path: str = None) -> None: global _db db_path = path or 
config.DATABASE_PATH Path(db_path).parent.mkdir(parents=True, exist_ok=True) - + _db = await aiosqlite.connect(db_path) _db.row_factory = aiosqlite.Row - + await _db.execute("PRAGMA journal_mode=WAL") await _db.execute("PRAGMA foreign_keys=ON") await _db.execute("PRAGMA busy_timeout=5000") @@ -137,11 +148,11 @@ async def execute_many(sql: str, params_list: list[tuple]) -> None: class transaction: """Async context manager for transactions.""" - + async def __aenter__(self): self.db = await get_db() return self.db - + async def __aexit__(self, exc_type, exc_val, exc_tb): if exc_type is None: await self.db.commit() @@ -153,25 +164,32 @@ class transaction: # Email # ============================================================================= -async def send_email(to: str, subject: str, html: str, text: str = None) -> bool: - """Send email via Resend API.""" +EMAIL_ADDRESSES = { + "transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>", +} + + +async def send_email( + to: str, subject: str, html: str, text: str = None, from_addr: str = None +) -> bool: + """Send email via Resend SDK.""" if not config.RESEND_API_KEY: print(f"[EMAIL] Would send to {to}: {subject}") return True - - async with httpx.AsyncClient() as client: - response = await client.post( - "https://api.resend.com/emails", - headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"}, - json={ - "from": config.EMAIL_FROM, - "to": to, - "subject": subject, - "html": html, - "text": text or html, - }, - ) - return response.status_code == 200 + + resend.api_key = config.RESEND_API_KEY + try: + resend.Emails.send({ + "from": from_addr or config.EMAIL_FROM, + "to": to, + "subject": subject, + "html": html, + "text": text or html, + }) + return True + except Exception as e: + print(f"[EMAIL] Error sending to {to}: {e}") + return False # ============================================================================= # CSRF Protection @@ -214,34 +232,34 @@ async def check_rate_limit(key: str, limit: int = None, 
window: int = None) -> t window = window or config.RATE_LIMIT_WINDOW now = datetime.utcnow() window_start = now - timedelta(seconds=window) - + # Clean old entries and count recent await execute( "DELETE FROM rate_limits WHERE key = ? AND timestamp < ?", (key, window_start.isoformat()) ) - + result = await fetch_one( "SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?", (key, window_start.isoformat()) ) count = result["count"] if result else 0 - + info = { "limit": limit, "remaining": max(0, limit - count - 1), "reset": int((window_start + timedelta(seconds=window)).timestamp()), } - + if count >= limit: return False, info - + # Record this request await execute( "INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)", (key, now.isoformat()) ) - + return True, info @@ -254,13 +272,13 @@ def rate_limit(limit: int = None, window: int = None, key_func=None): key = key_func() else: key = f"ip:{request.remote_addr}" - + allowed, info = await check_rate_limit(key, limit, window) - + if not allowed: response = {"error": "Rate limit exceeded", **info} return response, 429 - + return await f(*args, **kwargs) return decorated return decorator @@ -284,7 +302,7 @@ def setup_request_id(app): rid = request.headers.get("X-Request-ID") or secrets.token_hex(8) request_id_var.set(rid) g.request_id = rid - + @app.after_request async def add_request_id_header(response): response.headers["X-Request-ID"] = get_request_id() @@ -294,13 +312,11 @@ def setup_request_id(app): # Webhook Signature Verification # ============================================================================= - def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool: """Verify HMAC-SHA256 webhook signature.""" expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() return hmac.compare_digest(signature, expected) - # ============================================================================= # Soft Delete Helpers # 
============================================================================= @@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int: f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?", (cutoff,) ) + + +# ============================================================================= +# A/B Testing +# ============================================================================= + +def ab_test(experiment: str, variants: tuple = ("control", "treatment")): + """Assign visitor to an A/B test variant via cookie, tag Umami pageviews.""" + def decorator(f): + @wraps(f) + async def wrapper(*args, **kwargs): + cookie_key = f"ab_{experiment}" + assigned = request.cookies.get(cookie_key) + if assigned not in variants: + assigned = random.choice(variants) + + g.ab_variant = assigned + g.ab_tag = f"{experiment}-{assigned}" + + response = await make_response(await f(*args, **kwargs)) + response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60) + return response + return wrapper + return decorator \ No newline at end of file diff --git a/web/src/beanflows/migrations/migrate.py b/web/src/beanflows/migrations/migrate.py index 05aee3d..0c59397 100644 --- a/web/src/beanflows/migrations/migrate.py +++ b/web/src/beanflows/migrations/migrate.py @@ -1,51 +1,95 @@ """ -Simple migration runner. Runs schema.sql against the database. -""" -import sqlite3 -from pathlib import Path -import os -import sys +Sequential migration runner. + +Replays all migrations in order. All databases — fresh and existing — +go through the same path. No schema.sql fast-path. 
+ +- Scans versions/ for NNNN_*.py files and runs unapplied ones in order +- Each migration has an up(conn) function receiving an uncommitted connection +- All pending migrations share a single transaction (batch atomicity) +""" + +import importlib +import os +import re +import sqlite3 +import sys +from pathlib import Path -# Add parent to path for imports sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) from dotenv import load_dotenv load_dotenv() +VERSIONS_DIR = Path(__file__).parent / "versions" +VERSION_RE = re.compile(r"^(\d{4})_.+\.py$") -def migrate(): - """Run migrations.""" - # Get database path from env or default - db_path = os.getenv("DATABASE_PATH", "data/app.db") - - # Ensure directory exists +# Derived from the package path: …/src//migrations/migrate.py +_PACKAGE = Path(__file__).parent.parent.name # e.g. "myproject" + + +def _discover_versions(): + """Return sorted list of version file stems.""" + if not VERSIONS_DIR.is_dir(): + return [] + versions = [] + for f in sorted(VERSIONS_DIR.iterdir()): + if VERSION_RE.match(f.name): + versions.append(f.stem) + return versions + + +def migrate(db_path=None): + if db_path is None: + db_path = os.getenv("DATABASE_PATH", "data/app.db") Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - # Read schema - schema_path = Path(__file__).parent / "schema.sql" - schema = schema_path.read_text() - - # Connect and execute + conn = sqlite3.connect(db_path) - - # Enable WAL mode conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA foreign_keys=ON") - - # Run schema - conn.executescript(schema) + + # Ensure tracking table exists before anything else + conn.execute(""" + CREATE TABLE IF NOT EXISTS _migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + applied_at TEXT NOT NULL DEFAULT (datetime('now')) + ) + """) conn.commit() - - print(f"✓ Migrations complete: {db_path}") - - # Show tables + + versions = _discover_versions() + applied = { + row[0] + for row 
in conn.execute("SELECT name FROM _migrations").fetchall() + } + pending = [v for v in versions if v not in applied] + + if pending: + for name in pending: + print(f" Applying {name}...") + mod = importlib.import_module( + f"{_PACKAGE}.migrations.versions.{name}" + ) + mod.up(conn) + conn.execute( + "INSERT INTO _migrations (name) VALUES (?)", (name,) + ) + conn.commit() + print(f"✓ Applied {len(pending)} migration(s): {db_path}") + else: + print(f"✓ All migrations already applied: {db_path}") + + # Show tables (excluding internal sqlite/fts tables) cursor = conn.execute( - "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + "SELECT name FROM sqlite_master WHERE type='table'" + " AND name NOT LIKE 'sqlite_%'" + " ORDER BY name" ) tables = [row[0] for row in cursor.fetchall()] print(f" Tables: {', '.join(tables)}") - + conn.close() diff --git a/web/src/beanflows/migrations/schema.sql b/web/src/beanflows/migrations/schema.sql index 2305e2c..f3c44f8 100644 --- a/web/src/beanflows/migrations/schema.sql +++ b/web/src/beanflows/migrations/schema.sql @@ -1,6 +1,13 @@ -- BeanFlows Database Schema -- Run with: python -m beanflows.migrations.migrate +-- Migration tracking +CREATE TABLE IF NOT EXISTS _migrations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT UNIQUE NOT NULL, + applied_at TEXT NOT NULL DEFAULT (datetime('now')) +); + -- Users CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, @@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens ( CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token); CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id); +-- User Roles +CREATE TABLE IF NOT EXISTS user_roles ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE, + role TEXT NOT NULL, + granted_at TEXT NOT NULL DEFAULT (datetime('now')), + UNIQUE(user_id, role) +); +CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id); 
+CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role); + +-- Billing Customers (payment provider identity, separate from subscriptions) +CREATE TABLE IF NOT EXISTS billing_customers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL UNIQUE REFERENCES users(id), + provider_customer_id TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')) +); +CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id); +CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id); + -- Subscriptions CREATE TABLE IF NOT EXISTS subscriptions ( id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL UNIQUE REFERENCES users(id), + user_id INTEGER NOT NULL REFERENCES users(id), plan TEXT NOT NULL DEFAULT 'free', status TEXT NOT NULL DEFAULT 'free', - - paddle_customer_id TEXT, - paddle_subscription_id TEXT, - + provider_subscription_id TEXT, current_period_end TEXT, created_at TEXT NOT NULL, updated_at TEXT ); CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id); +CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id); -CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id); +-- Transactions +CREATE TABLE IF NOT EXISTS transactions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL REFERENCES users(id), + subscription_id INTEGER REFERENCES subscriptions(id), + provider_transaction_id TEXT UNIQUE, + type TEXT NOT NULL DEFAULT 'payment', + amount_cents INTEGER, + currency TEXT DEFAULT 'USD', + status TEXT NOT NULL DEFAULT 'pending', + metadata TEXT, + created_at TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id); +CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id); -- API Keys CREATE TABLE IF NOT EXISTS api_keys ( @@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks ( 
-- Items (example domain entity - replace with your domain)
-- deleted_at is a soft-delete marker; NULL means the row is live.
CREATE TABLE IF NOT EXISTS items (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    user_id INTEGER NOT NULL REFERENCES users(id),
    name TEXT NOT NULL,
    data TEXT,
    created_at TEXT NOT NULL,
    updated_at TEXT,
    deleted_at TEXT
);

CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);

-- Full-text search for items (optional)
-- External-content FTS5 table: rows mirror `items` (content='items'),
-- so the triggers below must keep the index in sync manually.
CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
    name,
    data,
    content='items',
    content_rowid='id'
);

-- FTS triggers: insert adds, delete issues the FTS5 'delete' command,
-- update is delete-then-insert (the required pattern for external content).
CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;

CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
END;

CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;
"""
Create Paddle products and prices for BeanFlows.

Run once per environment (sandbox, then production).
Prints resulting price IDs as a .env snippet.

Usage:
    uv run python -m beanflows.scripts.setup_paddle
"""

import os
import sys

from dotenv import load_dotenv
from paddle_billing import Client as PaddleClient
from paddle_billing import Environment, Options
from paddle_billing.Entities.Shared import (
    CurrencyCode,
    Money,
    TaxCategory,
    TimePeriod,  # hoisted: was imported inside the per-product loop
)
from paddle_billing.Resources.Prices.Operations import CreatePrice
from paddle_billing.Resources.Products.Operations import CreateProduct

load_dotenv()

PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")

if not PADDLE_API_KEY:
    print("ERROR: Set PADDLE_API_KEY in .env first")
    sys.exit(1)


# One spec per product to create.
#   env_key  -> .env variable the resulting price ID is printed under
#   price    -> amount in the currency's minor unit (cents)
#   interval -> billing interval for subscription-type products
PRODUCTS = [
    # Subscriptions
    {
        "name": "Starter",
        "env_key": "PADDLE_PRICE_STARTER",
        "price": 900,
        "currency": CurrencyCode.USD,
        "interval": "month",
        "type": "subscription",
    },
    {
        "name": "Pro",
        "env_key": "PADDLE_PRICE_PRO",
        "price": 2900,
        "currency": CurrencyCode.USD,
        "interval": "month",
        "type": "subscription",
    },
]


def main():
    """Create each product and its price in Paddle, then print a .env snippet."""
    env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
    paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))

    print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")

    env_lines = []

    for spec in PRODUCTS:
        # Create product
        product = paddle.products.create(CreateProduct(
            name=spec["name"],
            tax_category=TaxCategory.Standard,
        ))
        print(f"  Product: {spec['name']} -> {product.id}")

        # Create price (Money takes the amount as a string of minor units)
        price_kwargs = {
            "description": spec["name"],
            "product_id": product.id,
            "unit_price": Money(str(spec["price"]), spec["currency"]),
        }

        if spec["type"] == "subscription":
            # Honor the spec's interval instead of hardcoding "month", so a
            # yearly plan can be added to PRODUCTS without touching this code.
            price_kwargs["billing_cycle"] = TimePeriod(interval=spec["interval"], frequency=1)

        price = paddle.prices.create(CreatePrice(**price_kwargs))
        print(f"  Price:   {spec['env_key']} = {price.id}")

        env_lines.append(f"{spec['env_key']}={price.id}")

    print("\n# --- .env snippet ---")
    for line in env_lines:
        print(line)


if __name__ == "__main__":
    main()
The fonts and derivatives, +however, cannot be released under any other type of license. The +requirement for fonts to remain under this license does not apply +to any document created using the fonts or their derivatives. + +DEFINITIONS +"Font Software" refers to the set of files released by the Copyright +Holder(s) under this license and clearly marked as such. This may +include source files, build scripts and documentation. + +"Reserved Font Name" refers to any names specified as such after the +copyright statement(s). + +"Original Version" refers to the collection of Font Software components as +distributed by the Copyright Holder(s). + +"Modified Version" refers to any derivative made by adding to, deleting, +or substituting -- in part or in whole -- any of the components of the +Original Version, by changing formats or by porting the Font Software to a +new environment. + +"Author" refers to any designer, engineer, programmer, technical +writer or other person who contributed to the Font Software. + +PERMISSION & CONDITIONS +Permission is hereby granted, free of charge, to any person obtaining +a copy of the Font Software, to use, study, copy, merge, embed, modify, +redistribute, and sell modified and unmodified copies of the Font +Software, subject to the following conditions: + +1) Neither the Font Software nor any of its individual components, +in Original or Modified Versions, may be sold by itself. + +2) Original or Modified Versions of the Font Software may be bundled, +redistributed and/or sold with any software, provided that each copy +contains the above copyright notice and this license. These can be +included either as stand-alone text files, human-readable headers or +in the appropriate machine-readable metadata fields within text or +binary files as long as those fields can be easily viewed by the user. 
+ +3) No Modified Version of the Font Software may use the Reserved Font +Name(s) unless explicit written permission is granted by the corresponding +Copyright Holder. This restriction only applies to the primary font name as +presented to the users. + +4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font +Software shall not be used to promote, endorse or advertise any +Modified Version, except to acknowledge the contribution(s) of the +Copyright Holder(s) and the Author(s) or with their explicit written +permission. + +5) The Font Software, modified or unmodified, in part or in whole, +must be distributed entirely under this license, and must not be +distributed under any other license. The requirement for fonts to +remain under this license does not apply to any document created +using the Font Software. + +TERMINATION +This license becomes null and void if any of the above conditions are +not met. + +DISCLAIMER +THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT +OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE +COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL +DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM +OTHER DEALINGS IN THE FONT SOFTWARE. \ No newline at end of file diff --git a/web/src/beanflows/worker.py b/web/src/beanflows/worker.py index ca42a93..bf8d066 100644 --- a/web/src/beanflows/worker.py +++ b/web/src/beanflows/worker.py @@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email HANDLERS: dict[str, callable] = {} +def _email_wrap(body: str) -> str: + """Wrap email body in a branded layout with inline CSS.""" + return f"""\ + + + + + + +
+ + + + + + + +
+ {config.APP_NAME} +
+ {body} +
+ © {config.APP_NAME} · You received this because you have an account. +
+
+ +""" + + +def _email_button(url: str, label: str) -> str: + """Render a branded CTA button for email.""" + return ( + f'' + f'
' + f'' + f'{label}
' + ) + + def task(name: str): """Decorator to register a task handler.""" def decorator(f): @@ -46,7 +86,7 @@ async def get_pending_tasks(limit: int = 10) -> list[dict]: now = datetime.utcnow().isoformat() return await fetch_all( """ - SELECT * FROM tasks + SELECT * FROM tasks WHERE status = 'pending' AND run_at <= ? ORDER BY run_at ASC LIMIT ? @@ -66,15 +106,15 @@ async def mark_complete(task_id: int) -> None: async def mark_failed(task_id: int, error: str, retries: int) -> None: """Mark task as failed, schedule retry if attempts remain.""" max_retries = 3 - + if retries < max_retries: # Exponential backoff: 1min, 5min, 25min delay = timedelta(minutes=5 ** retries) run_at = datetime.utcnow() + delay - + await execute( """ - UPDATE tasks + UPDATE tasks SET status = 'pending', error = ?, retries = ?, run_at = ? WHERE id = ? """, @@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None: subject=payload["subject"], html=payload["html"], text=payload.get("text"), + from_addr=payload.get("from_addr"), ) @@ -106,35 +147,37 @@ async def handle_send_email(payload: dict) -> None: async def handle_send_magic_link(payload: dict) -> None: """Send magic link email.""" link = f"{config.BASE_URL}/auth/verify?token={payload['token']}" - - html = f""" -

Sign in to {config.APP_NAME}

-

Click the link below to sign in:

-

{link}

-

This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.

-

If you didn't request this, you can safely ignore this email.

- """ - + + body = ( + f'

Sign in to {config.APP_NAME}

' + f"

Click the button below to sign in. This link expires in " + f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.

" + f"{_email_button(link, 'Sign In')}" + f'

If the button doesn\'t work, copy and paste this URL into your browser:

' + f'

{link}

' + f'

If you didn\'t request this, you can safely ignore this email.

' + ) + await send_email( to=payload["email"], subject=f"Sign in to {config.APP_NAME}", - html=html, + html=_email_wrap(body), ) @task("send_welcome") async def handle_send_welcome(payload: dict) -> None: """Send welcome email to new user.""" - html = f""" -

Welcome to {config.APP_NAME}!

-

Thanks for signing up. We're excited to have you.

-

Go to your dashboard

- """ - + body = ( + f'

Welcome to {config.APP_NAME}!

' + f"

Thanks for signing up. We're excited to have you.

async def run_worker(poll_interval: float = 1.0) -> None:
    """Main worker loop.

    Polls the tasks table forever, processing due tasks in batches of 10.
    Sleeps only when a poll comes back empty, so a backlog drains at full
    speed; unexpected errors back off five times longer than the normal
    poll interval.
    """
    print("[WORKER] Starting...")
    await init_db()

    while True:
        try:
            batch = await get_pending_tasks(limit=10)

            for task in batch:
                await process_task(task)

            # Empty poll: idle briefly instead of busy-looping.
            if not batch:
                await asyncio.sleep(poll_interval)

        except Exception as e:
            # DB hiccup or similar — log and back off harder.
            print(f"[WORKER] Error: {e}")
            await asyncio.sleep(poll_interval * 5)


async def run_scheduler() -> None:
    """Schedule periodic cleanup tasks.

    Enqueues the three housekeeping tasks once per hour; on error it
    retries after a minute.
    """
    print("[SCHEDULER] Starting...")
    await init_db()

    while True:
        try:
            # Schedule cleanup tasks every hour
            await enqueue("cleanup_expired_tokens")
            await enqueue("cleanup_rate_limits")
            await enqueue("cleanup_old_tasks")

            await asyncio.sleep(3600)  # 1 hour

        except Exception as e:
            print(f"[SCHEDULER] Error: {e}")
            await asyncio.sleep(60)


if __name__ == "__main__":
    import sys

    # `python -m ... scheduler` runs the hourly scheduler; anything else
    # (including no argument) runs the task worker.
    if len(sys.argv) > 1 and sys.argv[1] == "scheduler":
        asyncio.run(run_scheduler())
    else:
        asyncio.run(run_worker())
@pytest.fixture
def create_subscription(db):
    """Factory: insert a subscriptions row for a user and return its id.

    Fix: the previous revision dropped the ``INSERT INTO subscriptions``
    statement head and the ``) as cursor:`` context-manager close, leaving
    the fixture syntactically invalid. Restored with the provider-agnostic
    column set (no paddle_* columns, no provider_customer_id — that now
    lives in billing_customers).
    """
    async def _create(
        user_id: int,
        plan: str = "pro",
        status: str = "active",
        provider_subscription_id: str = "sub_test456",
        current_period_end: str = "2025-03-01T00:00:00Z",
    ) -> int:
        now = datetime.utcnow().isoformat()
        async with db.execute(
            """INSERT INTO subscriptions
            (user_id, plan, status,
             provider_subscription_id, current_period_end, created_at, updated_at)
            VALUES (?, ?, ?, ?, ?, ?, ?)""",
            (user_id, plan, status, provider_subscription_id,
             current_period_end, now, now),
        ) as cursor:
            sub_id = cursor.lastrowid
        await db.commit()
        return sub_id
    return _create
@pytest.fixture
def grant_role(db):
    """Factory: grant a role to a user.

    Uses INSERT OR IGNORE, so granting the same role twice is a no-op
    rather than a uniqueness violation.
    """
    async def _grant(user_id: int, role: str) -> None:
        await db.execute(
            "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
            (user_id, role),
        )
        await db.commit()
    return _grant


@pytest.fixture
async def admin_client(app, test_user, grant_role):
    """Test client with admin role and session['user_id'] pre-set."""
    await grant_role(test_user["id"], "admin")
    async with app.test_client() as client:
        # Pre-authenticate: write the user id straight into the session
        # instead of walking the magic-link flow.
        async with client.session_transaction() as session:
            session["user_id"] = test_user["id"]
        yield client
lambda: mock_client, + ) + return mock_client + + def make_webhook_payload( event_type: str, subscription_id: str = "sub_test456", @@ -172,76 +236,8 @@ def make_webhook_payload( } -def sign_payload(payload_bytes: bytes, secret: str = "whsec_test_secret") -> str: - """Compute HMAC-SHA256 signature for a webhook payload.""" - return hmac.new(secret.encode(), payload_bytes, hashlib.sha256).hexdigest() - - -# ── Analytics mock data ────────────────────────────────────── - -MOCK_TIME_SERIES = [ - {"market_year": 2018, "Production": 165000, "Exports": 115000, "Imports": 105000, - "Ending_Stocks": 33000, "Total_Distribution": 160000}, - {"market_year": 2019, "Production": 168000, "Exports": 118000, "Imports": 108000, - "Ending_Stocks": 34000, "Total_Distribution": 163000}, - {"market_year": 2020, "Production": 170000, "Exports": 120000, "Imports": 110000, - "Ending_Stocks": 35000, "Total_Distribution": 165000}, - {"market_year": 2021, "Production": 175000, "Exports": 125000, "Imports": 115000, - "Ending_Stocks": 36000, "Total_Distribution": 170000}, - {"market_year": 2022, "Production": 172000, "Exports": 122000, "Imports": 112000, - "Ending_Stocks": 34000, "Total_Distribution": 168000}, -] - -MOCK_TOP_COUNTRIES = [ - {"country_name": "Brazil", "country_code": "BR", "market_year": 2022, "Production": 65000}, - {"country_name": "Vietnam", "country_code": "VN", "market_year": 2022, "Production": 30000}, - {"country_name": "Colombia", "country_code": "CO", "market_year": 2022, "Production": 14000}, -] - -MOCK_STU_TREND = [ - {"market_year": 2020, "Stock_to_Use_Ratio_pct": 21.2}, - {"market_year": 2021, "Stock_to_Use_Ratio_pct": 21.1}, - {"market_year": 2022, "Stock_to_Use_Ratio_pct": 20.2}, -] - -MOCK_BALANCE = [ - {"market_year": 2020, "Production": 170000, "Total_Distribution": 165000, "Supply_Demand_Balance": 5000}, - {"market_year": 2021, "Production": 175000, "Total_Distribution": 170000, "Supply_Demand_Balance": 5000}, - {"market_year": 2022, "Production": 172000, 
# Placeholder signature in Paddle's "ts=...;h1=..." wire format. Real HMAC
# verification is bypassed by the autouse mock_paddle_verifier fixture,
# which only rejects empty or "invalid_signature" values.
_DUMMY_SIGNATURE = "ts=1234567890;h1=dummy_signature"


def sign_payload(payload_bytes: bytes) -> str:
    """Return a dummy signature for Paddle webhook tests (Verifier is mocked)."""
    return _DUMMY_SIGNATURE
get_subscription_by_provider_id, is_within_limits, + record_transaction, update_subscription_status, + upsert_billing_customer, upsert_subscription, ) from beanflows.core import config @@ -45,7 +48,6 @@ class TestUpsertSubscription: user_id=test_user["id"], plan="pro", status="active", - provider_customer_id="cust_abc", provider_subscription_id="sub_xyz", current_period_end="2025-06-01T00:00:00Z", ) @@ -53,39 +55,53 @@ class TestUpsertSubscription: row = await get_subscription(test_user["id"]) assert row["plan"] == "pro" assert row["status"] == "active" - - assert row["paddle_customer_id"] == "cust_abc" - assert row["paddle_subscription_id"] == "sub_xyz" - + assert row["provider_subscription_id"] == "sub_xyz" assert row["current_period_end"] == "2025-06-01T00:00:00Z" - async def test_update_existing_subscription(self, db, test_user, create_subscription): - original_id = await create_subscription( - test_user["id"], plan="starter", status="active", - - paddle_subscription_id="sub_old", - + async def test_update_existing_by_provider_subscription_id(self, db, test_user): + """upsert finds existing by provider_subscription_id, not user_id.""" + await upsert_subscription( + user_id=test_user["id"], + plan="starter", + status="active", + provider_subscription_id="sub_same", ) returned_id = await upsert_subscription( user_id=test_user["id"], plan="pro", status="active", - provider_customer_id="cust_new", - provider_subscription_id="sub_new", + provider_subscription_id="sub_same", ) - assert returned_id == original_id row = await get_subscription(test_user["id"]) assert row["plan"] == "pro" + assert row["provider_subscription_id"] == "sub_same" - assert row["paddle_subscription_id"] == "sub_new" - + async def test_different_provider_id_creates_new(self, db, test_user): + """Different provider_subscription_id creates a new row (multi-sub support).""" + await upsert_subscription( + user_id=test_user["id"], + plan="starter", + status="active", + 
class TestUpsertBillingCustomer:
    """Round-trip coverage for upsert_billing_customer / get_billing_customer."""

    async def test_creates_billing_customer(self, db, test_user):
        # First upsert for a user inserts a fresh row.
        await upsert_billing_customer(test_user["id"], "cust_abc")
        row = await get_billing_customer(test_user["id"])
        assert row is not None
        assert row["provider_customer_id"] == "cust_abc"

    async def test_updates_existing_customer(self, db, test_user):
        # A second upsert for the same user overwrites rather than duplicates
        # (billing_customers.user_id is UNIQUE).
        await upsert_billing_customer(test_user["id"], "cust_old")
        await upsert_billing_customer(test_user["id"], "cust_new")
        row = await get_billing_customer(test_user["id"])
        assert row["provider_customer_id"] == "cust_new"

    async def test_get_returns_none_for_unknown_user(self, db):
        row = await get_billing_customer(99999)
        assert row is None
create_subscription(test_user["id"], paddle_subscription_id="sub_findme") - + async def test_finds_by_provider_subscription_id(self, db, test_user, create_subscription): + await create_subscription(test_user["id"], provider_subscription_id="sub_findme") result = await get_subscription_by_provider_id("sub_findme") assert result is not None assert result["user_id"] == test_user["id"] @@ -117,18 +153,14 @@ class TestGetSubscriptionByProviderId: class TestUpdateSubscriptionStatus: async def test_updates_status(self, db, test_user, create_subscription): - - await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_upd") - + await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_upd") await update_subscription_status("sub_upd", status="cancelled") row = await get_subscription(test_user["id"]) assert row["status"] == "cancelled" assert row["updated_at"] is not None async def test_updates_extra_fields(self, db, test_user, create_subscription): - - await create_subscription(test_user["id"], paddle_subscription_id="sub_extra") - + await create_subscription(test_user["id"], provider_subscription_id="sub_extra") await update_subscription_status( "sub_extra", status="active", @@ -141,9 +173,7 @@ class TestUpdateSubscriptionStatus: assert row["current_period_end"] == "2026-01-01T00:00:00Z" async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription): - - await create_subscription(test_user["id"], paddle_subscription_id="sub_known", status="active") - + await create_subscription(test_user["id"], provider_subscription_id="sub_known", status="active") await update_subscription_status("sub_unknown", status="expired") row = await get_subscription(test_user["id"]) assert row["status"] == "active" # unchanged @@ -155,22 +185,22 @@ class TestUpdateSubscriptionStatus: class TestCanAccessFeature: async def test_no_subscription_gets_free_features(self, db, test_user): - assert await 
can_access_feature(test_user["id"], "dashboard") is True + assert await can_access_feature(test_user["id"], "basic") is True assert await can_access_feature(test_user["id"], "export") is False assert await can_access_feature(test_user["id"], "api") is False async def test_active_pro_gets_all_features(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="pro", status="active") - assert await can_access_feature(test_user["id"], "dashboard") is True + assert await can_access_feature(test_user["id"], "basic") is True assert await can_access_feature(test_user["id"], "export") is True assert await can_access_feature(test_user["id"], "api") is True assert await can_access_feature(test_user["id"], "priority_support") is True async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="starter", status="active") - assert await can_access_feature(test_user["id"], "dashboard") is True + assert await can_access_feature(test_user["id"], "basic") is True assert await can_access_feature(test_user["id"], "export") is True - assert await can_access_feature(test_user["id"], "all_commodities") is False + assert await can_access_feature(test_user["id"], "api") is False async def test_cancelled_still_has_features(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="pro", status="cancelled") @@ -183,7 +213,7 @@ class TestCanAccessFeature: async def test_expired_falls_back_to_free(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="pro", status="expired") assert await can_access_feature(test_user["id"], "api") is False - assert await can_access_feature(test_user["id"], "dashboard") is True + assert await can_access_feature(test_user["id"], "basic") is True async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription): await create_subscription(test_user["id"], 
plan="pro", status="past_due") @@ -203,30 +233,28 @@ class TestCanAccessFeature: # ════════════════════════════════════════════════════════════ class TestIsWithinLimits: - async def test_free_user_no_api_calls(self, db, test_user): - assert await is_within_limits(test_user["id"], "api_calls", 0) is False + async def test_free_user_within_limits(self, db, test_user): + assert await is_within_limits(test_user["id"], "items", 50) is True - async def test_free_user_commodity_limit(self, db, test_user): - assert await is_within_limits(test_user["id"], "commodities", 0) is True - assert await is_within_limits(test_user["id"], "commodities", 1) is False + async def test_free_user_at_limit(self, db, test_user): + assert await is_within_limits(test_user["id"], "items", 100) is False - async def test_free_user_history_limit(self, db, test_user): - assert await is_within_limits(test_user["id"], "history_years", 4) is True - assert await is_within_limits(test_user["id"], "history_years", 5) is False + async def test_free_user_over_limit(self, db, test_user): + assert await is_within_limits(test_user["id"], "items", 150) is False async def test_pro_unlimited(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="pro", status="active") - assert await is_within_limits(test_user["id"], "commodities", 999999) is True + assert await is_within_limits(test_user["id"], "items", 999999) is True assert await is_within_limits(test_user["id"], "api_calls", 999999) is True async def test_starter_limits(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="starter", status="active") - assert await is_within_limits(test_user["id"], "api_calls", 9999) is True - assert await is_within_limits(test_user["id"], "api_calls", 10000) is False + assert await is_within_limits(test_user["id"], "items", 999) is True + assert await is_within_limits(test_user["id"], "items", 1000) is False async def 
test_expired_pro_gets_free_limits(self, db, test_user, create_subscription): await create_subscription(test_user["id"], plan="pro", status="expired") - assert await is_within_limits(test_user["id"], "api_calls", 0) is False + assert await is_within_limits(test_user["id"], "items", 100) is False async def test_unknown_resource_returns_false(self, db, test_user): assert await is_within_limits(test_user["id"], "unicorns", 0) is False @@ -238,7 +266,7 @@ class TestIsWithinLimits: # ════════════════════════════════════════════════════════════ STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"] -FEATURES = ["dashboard", "export", "api", "priority_support"] +FEATURES = ["basic", "export", "api", "priority_support"] ACTIVE_STATUSES = {"active", "on_trial", "cancelled"} @@ -282,9 +310,9 @@ async def test_plan_feature_matrix(db, test_user, create_subscription, plan, fea @pytest.mark.parametrize("plan", PLANS) @pytest.mark.parametrize("resource,at_limit", [ - ("commodities", 1), - ("commodities", 65), - ("api_calls", 0), + ("items", 100), + ("items", 1000), + ("api_calls", 1000), ("api_calls", 10000), ]) async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit): @@ -307,11 +335,11 @@ async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resou # ════════════════════════════════════════════════════════════ class TestLimitsHypothesis: - @given(count=st.integers(min_value=0, max_value=100)) + @given(count=st.integers(min_value=0, max_value=10000)) @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture]) - async def test_free_limit_boundary_commodities(self, db, test_user, count): - result = await is_within_limits(test_user["id"], "commodities", count) - assert result == (count < 1) + async def test_free_limit_boundary_items(self, db, test_user, count): + result = await is_within_limits(test_user["id"], "items", count) + assert result == 
(count < 100) @given(count=st.integers(min_value=0, max_value=100000)) @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture]) @@ -319,7 +347,56 @@ class TestLimitsHypothesis: # Use upsert to avoid duplicate inserts across Hypothesis examples await upsert_subscription( user_id=test_user["id"], plan="pro", status="active", - provider_customer_id="cust_hyp", provider_subscription_id="sub_hyp", + provider_subscription_id="sub_hyp", ) - result = await is_within_limits(test_user["id"], "commodities", count) + result = await is_within_limits(test_user["id"], "items", count) assert result is True + + +# ════════════════════════════════════════════════════════════ +# record_transaction +# ════════════════════════════════════════════════════════════ + +class TestRecordTransaction: + async def test_inserts_transaction(self, db, test_user): + txn_id = await record_transaction( + user_id=test_user["id"], + provider_transaction_id="txn_abc123", + type="payment", + amount_cents=2999, + currency="EUR", + status="completed", + ) + assert txn_id is not None and txn_id > 0 + + from beanflows.core import fetch_one + row = await fetch_one( + "SELECT * FROM transactions WHERE provider_transaction_id = ?", + ("txn_abc123",), + ) + assert row is not None + assert row["user_id"] == test_user["id"] + assert row["amount_cents"] == 2999 + assert row["currency"] == "EUR" + assert row["status"] == "completed" + + async def test_idempotent_on_duplicate_provider_id(self, db, test_user): + await record_transaction( + user_id=test_user["id"], + provider_transaction_id="txn_dup", + amount_cents=1000, + ) + # Second insert with same provider_transaction_id should be ignored + await record_transaction( + user_id=test_user["id"], + provider_transaction_id="txn_dup", + amount_cents=9999, + ) + + from beanflows.core import fetch_all + rows = await fetch_all( + "SELECT * FROM transactions WHERE provider_transaction_id = ?", + ("txn_dup",), + ) + assert 
len(rows) == 1 + assert rows[0]["amount_cents"] == 1000 # original value preserved diff --git a/web/tests/test_billing_hooks.py b/web/tests/test_billing_hooks.py new file mode 100644 index 0000000..863cd16 --- /dev/null +++ b/web/tests/test_billing_hooks.py @@ -0,0 +1,122 @@ +""" +Tests for the billing event hook system. +""" +import pytest + +from beanflows.billing.routes import _billing_hooks, _fire_hooks, on_billing_event + + +@pytest.fixture(autouse=True) +def clear_hooks(): + """Ensure hooks are clean before and after each test.""" + _billing_hooks.clear() + yield + _billing_hooks.clear() + + +# ════════════════════════════════════════════════════════════ +# Registration +# ════════════════════════════════════════════════════════════ + +class TestOnBillingEvent: + def test_registers_single_event(self): + @on_billing_event("subscription.activated") + async def my_hook(event_type, data): + pass + + assert "subscription.activated" in _billing_hooks + assert my_hook in _billing_hooks["subscription.activated"] + + def test_registers_multiple_events(self): + @on_billing_event("subscription.activated", "subscription.updated") + async def my_hook(event_type, data): + pass + + assert my_hook in _billing_hooks["subscription.activated"] + assert my_hook in _billing_hooks["subscription.updated"] + + def test_multiple_hooks_per_event(self): + @on_billing_event("subscription.activated") + async def hook_a(event_type, data): + pass + + @on_billing_event("subscription.activated") + async def hook_b(event_type, data): + pass + + assert len(_billing_hooks["subscription.activated"]) == 2 + + def test_decorator_returns_original_function(self): + @on_billing_event("test_event") + async def my_hook(event_type, data): + pass + + assert my_hook.__name__ == "my_hook" + + +# ════════════════════════════════════════════════════════════ +# Firing +# ════════════════════════════════════════════════════════════ + +class TestFireHooks: + async def test_fires_registered_hook(self): + calls = 
[] + + @on_billing_event("subscription.activated") + async def recorder(event_type, data): + calls.append((event_type, data)) + + await _fire_hooks("subscription.activated", {"id": "sub_123"}) + assert len(calls) == 1 + assert calls[0] == ("subscription.activated", {"id": "sub_123"}) + + async def test_no_hooks_registered_is_noop(self): + # Should not raise + await _fire_hooks("unregistered_event", {"id": "sub_123"}) + + async def test_fires_all_hooks_for_event(self): + calls = [] + + @on_billing_event("subscription.activated") + async def hook_a(event_type, data): + calls.append("a") + + @on_billing_event("subscription.activated") + async def hook_b(event_type, data): + calls.append("b") + + await _fire_hooks("subscription.activated", {}) + assert calls == ["a", "b"] + + +# ════════════════════════════════════════════════════════════ +# Error isolation +# ════════════════════════════════════════════════════════════ + +class TestHookErrorIsolation: + async def test_failing_hook_does_not_block_others(self): + calls = [] + + @on_billing_event("subscription.activated") + async def failing_hook(event_type, data): + raise RuntimeError("boom") + + @on_billing_event("subscription.activated") + async def good_hook(event_type, data): + calls.append("ok") + + # Should not raise despite first hook failing + await _fire_hooks("subscription.activated", {}) + assert calls == ["ok"] + + async def test_failing_hook_is_logged(self, caplog): + @on_billing_event("subscription.activated") + async def bad_hook(event_type, data): + raise ValueError("test error") + + import logging + with caplog.at_level(logging.ERROR): + await _fire_hooks("subscription.activated", {}) + + assert "bad_hook" in caplog.text + assert "test error" in caplog.text diff --git a/web/tests/test_billing_routes.py b/web/tests/test_billing_routes.py index 4830874..7043f12 100644 --- a/web/tests/test_billing_routes.py +++ b/web/tests/test_billing_routes.py @@ -1,12 +1,15 @@ """ Route integration tests for Paddle 
billing endpoints. -External Paddle API calls mocked with respx. -""" -import json -import httpx +Paddle SDK calls mocked via mock_paddle_client fixture. + +""" + + + +from unittest.mock import MagicMock + import pytest -import respx CHECKOUT_METHOD = "POST" @@ -54,24 +57,16 @@ class TestCheckoutRoute: assert response.status_code in (302, 303, 307) - @respx.mock - async def test_creates_checkout_session(self, auth_client, db, test_user): - - respx.post("https://api.paddle.com/transactions").mock( - return_value=httpx.Response(200, json={ - "data": { - "checkout": { - "url": "https://checkout.paddle.com/test_123" - } - } - }) - ) - + async def test_creates_checkout_session(self, auth_client, db, test_user, mock_paddle_client): + mock_txn = MagicMock() + mock_txn.checkout.url = "https://checkout.paddle.com/test_123" + mock_paddle_client.transactions.create.return_value = mock_txn response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False) - assert response.status_code in (302, 303, 307) + mock_paddle_client.transactions.create.assert_called_once() + async def test_invalid_plan_rejected(self, auth_client, db, test_user): @@ -82,20 +77,13 @@ class TestCheckoutRoute: - @respx.mock - async def test_api_error_propagates(self, auth_client, db, test_user): - - respx.post("https://api.paddle.com/transactions").mock( - return_value=httpx.Response(500, json={"error": "server error"}) - ) - - with pytest.raises(httpx.HTTPStatusError): - + async def test_api_error_propagates(self, auth_client, db, test_user, mock_paddle_client): + mock_paddle_client.transactions.create.side_effect = Exception("API error") + with pytest.raises(Exception, match="API error"): await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}") - # ════════════════════════════════════════════════════════════ # Manage subscription / Portal # ════════════════════════════════════════════════════════════ @@ -110,24 +98,18 @@ class TestManageRoute: response = await 
auth_client.post("/billing/manage", follow_redirects=False) assert response.status_code in (302, 303, 307) - @respx.mock - async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription): - await create_subscription(test_user["id"], paddle_subscription_id="sub_test") - - respx.get("https://api.paddle.com/subscriptions/sub_test").mock( - return_value=httpx.Response(200, json={ - "data": { - "management_urls": { - "update_payment_method": "https://paddle.com/manage/test_123" - } - } - }) - ) + async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription, mock_paddle_client): + await create_subscription(test_user["id"], provider_subscription_id="sub_test") + mock_sub = MagicMock() + mock_sub.management_urls.update_payment_method = "https://paddle.com/manage/test_123" + mock_paddle_client.subscriptions.get.return_value = mock_sub response = await auth_client.post("/billing/manage", follow_redirects=False) assert response.status_code in (302, 303, 307) + mock_paddle_client.subscriptions.get.assert_called_once_with("sub_test") + @@ -145,18 +127,14 @@ class TestCancelRoute: response = await auth_client.post("/billing/cancel", follow_redirects=False) assert response.status_code in (302, 303, 307) - @respx.mock - async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription): - - await create_subscription(test_user["id"], paddle_subscription_id="sub_test") - - respx.post("https://api.paddle.com/subscriptions/sub_test/cancel").mock( - return_value=httpx.Response(200, json={"data": {}}) - ) + async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription, mock_paddle_client): + await create_subscription(test_user["id"], provider_subscription_id="sub_test") response = await auth_client.post("/billing/cancel", follow_redirects=False) assert response.status_code in (302, 303, 307) + mock_paddle_client.subscriptions.cancel.assert_called_once() + @@ -167,8 +145,9 @@ class 
TestCancelRoute: # subscription_required decorator # ════════════════════════════════════════════════════════════ -from beanflows.billing.routes import subscription_required -from quart import Blueprint +from quart import Blueprint # noqa: E402 + +from beanflows.auth.routes import subscription_required # noqa: E402 test_bp = Blueprint("test", __name__) diff --git a/web/tests/test_billing_webhooks.py b/web/tests/test_billing_webhooks.py index 37a3251..96e291d 100644 --- a/web/tests/test_billing_webhooks.py +++ b/web/tests/test_billing_webhooks.py @@ -5,13 +5,13 @@ Covers signature verification, event parsing, subscription lifecycle transitions import json import pytest +from conftest import make_webhook_payload, sign_payload + from hypothesis import HealthCheck, given from hypothesis import settings as h_settings from hypothesis import strategies as st -from beanflows.billing.routes import get_subscription - -from conftest import make_webhook_payload, sign_payload +from beanflows.billing.routes import get_billing_customer, get_subscription WEBHOOK_PATH = "/billing/webhook/paddle" @@ -72,18 +72,19 @@ class TestWebhookSignature: async def test_modified_payload_rejected(self, client, db, test_user): + # Paddle SDK Verifier handles tamper detection internally. + # We test signature rejection via test_invalid_signature_rejected above. + # This test verifies the Verifier is actually called by sending + # a payload with an explicitly bad signature. 
payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"])) payload_bytes = json.dumps(payload).encode() - sig = sign_payload(payload_bytes) - tampered = payload_bytes + b"extra" - # Paddle/LemonSqueezy: HMAC signature verification fails before JSON parsing response = await client.post( WEBHOOK_PATH, - data=tampered, - headers={SIG_HEADER: sig, "Content-Type": "application/json"}, + data=payload_bytes, + headers={SIG_HEADER: "invalid_signature", "Content-Type": "application/json"}, ) - assert response.status_code in (400, 401) + assert response.status_code == 400 async def test_empty_payload_rejected(self, client, db): @@ -105,7 +106,7 @@ class TestWebhookSignature: class TestWebhookSubscriptionActivated: - async def test_creates_subscription(self, client, db, test_user): + async def test_creates_subscription_and_billing_customer(self, client, db, test_user): payload = make_webhook_payload( "subscription.activated", user_id=str(test_user["id"]), @@ -126,10 +127,14 @@ class TestWebhookSubscriptionActivated: assert sub["plan"] == "starter" assert sub["status"] == "active" + bc = await get_billing_customer(test_user["id"]) + assert bc is not None + assert bc["provider_customer_id"] == "ctm_test123" + class TestWebhookSubscriptionUpdated: async def test_updates_subscription_status(self, client, db, test_user, create_subscription): - await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") + await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456") payload = make_webhook_payload( "subscription.updated", @@ -152,7 +157,7 @@ class TestWebhookSubscriptionUpdated: class TestWebhookSubscriptionCanceled: async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription): - await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") + await create_subscription(test_user["id"], status="active", 
provider_subscription_id="sub_test456") payload = make_webhook_payload( "subscription.canceled", @@ -174,7 +179,7 @@ class TestWebhookSubscriptionCanceled: class TestWebhookSubscriptionPastDue: async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription): - await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") + await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456") payload = make_webhook_payload( "subscription.past_due", @@ -209,7 +214,7 @@ class TestWebhookSubscriptionPastDue: ]) async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status): if event_type != "subscription.activated": - await create_subscription(test_user["id"], paddle_subscription_id="sub_test456") + await create_subscription(test_user["id"], provider_subscription_id="sub_test456") payload = make_webhook_payload(event_type, user_id=str(test_user["id"])) payload_bytes = json.dumps(payload).encode() diff --git a/web/tests/test_roles.py b/web/tests/test_roles.py new file mode 100644 index 0000000..378d712 --- /dev/null +++ b/web/tests/test_roles.py @@ -0,0 +1,242 @@ +""" +Tests for role-based access control: role_required decorator, grant/revoke/ensure_admin_role, +and admin route protection. 
+""" +import pytest +from quart import Blueprint + +from beanflows.auth.routes import ( + ensure_admin_role, + grant_role, + revoke_role, + role_required, +) +from beanflows import core + + +# ════════════════════════════════════════════════════════════ +# grant_role / revoke_role +# ════════════════════════════════════════════════════════════ + +class TestGrantRole: + async def test_grants_role(self, db, test_user): + await grant_role(test_user["id"], "admin") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ?", + (test_user["id"],), + ) + assert row is not None + assert row["role"] == "admin" + + async def test_idempotent(self, db, test_user): + await grant_role(test_user["id"], "admin") + await grant_role(test_user["id"], "admin") + rows = await core.fetch_all( + "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'", + (test_user["id"],), + ) + assert len(rows) == 1 + + +class TestRevokeRole: + async def test_revokes_existing_role(self, db, test_user): + await grant_role(test_user["id"], "admin") + await revoke_role(test_user["id"], "admin") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'", + (test_user["id"],), + ) + assert row is None + + async def test_noop_for_missing_role(self, db, test_user): + # Should not raise + await revoke_role(test_user["id"], "nonexistent") + + +# ════════════════════════════════════════════════════════════ +# ensure_admin_role +# ════════════════════════════════════════════════════════════ + +class TestEnsureAdminRole: + async def test_grants_admin_for_listed_email(self, db, test_user): + core.config.ADMIN_EMAILS = ["test@example.com"] + try: + await ensure_admin_role(test_user["id"], "test@example.com") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ? 
AND role = 'admin'", + (test_user["id"],), + ) + assert row is not None + finally: + core.config.ADMIN_EMAILS = [] + + async def test_skips_for_unlisted_email(self, db, test_user): + core.config.ADMIN_EMAILS = ["boss@example.com"] + try: + await ensure_admin_role(test_user["id"], "test@example.com") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'", + (test_user["id"],), + ) + assert row is None + finally: + core.config.ADMIN_EMAILS = [] + + async def test_empty_admin_emails_grants_nothing(self, db, test_user): + core.config.ADMIN_EMAILS = [] + await ensure_admin_role(test_user["id"], "test@example.com") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'", + (test_user["id"],), + ) + assert row is None + + async def test_case_insensitive_matching(self, db, test_user): + core.config.ADMIN_EMAILS = ["test@example.com"] + try: + await ensure_admin_role(test_user["id"], "Test@Example.COM") + row = await core.fetch_one( + "SELECT role FROM user_roles WHERE user_id = ? 
AND role = 'admin'", + (test_user["id"],), + ) + assert row is not None + finally: + core.config.ADMIN_EMAILS = [] + + +# ════════════════════════════════════════════════════════════ +# role_required decorator +# ════════════════════════════════════════════════════════════ + +role_test_bp = Blueprint("role_test", __name__) + + +@role_test_bp.route("/admin-only") +@role_required("admin") +async def admin_only_route(): + return "admin-ok", 200 + + +@role_test_bp.route("/multi-role") +@role_required("admin", "editor") +async def multi_role_route(): + return "multi-ok", 200 + + +class TestRoleRequired: + @pytest.fixture + async def role_app(self, app): + app.register_blueprint(role_test_bp) + return app + + @pytest.fixture + async def role_client(self, role_app): + async with role_app.test_client() as c: + yield c + + async def test_redirects_unauthenticated(self, role_client, db): + response = await role_client.get("/admin-only", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + async def test_rejects_user_without_role(self, role_client, db, test_user): + async with role_client.session_transaction() as sess: + sess["user_id"] = test_user["id"] + + response = await role_client.get("/admin-only", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + async def test_allows_user_with_matching_role(self, role_client, db, test_user): + await grant_role(test_user["id"], "admin") + async with role_client.session_transaction() as sess: + sess["user_id"] = test_user["id"] + + response = await role_client.get("/admin-only") + assert response.status_code == 200 + + async def test_multi_role_allows_any_match(self, role_client, db, test_user): + await grant_role(test_user["id"], "editor") + async with role_client.session_transaction() as sess: + sess["user_id"] = test_user["id"] + + response = await role_client.get("/multi-role") + assert response.status_code == 200 + + async def test_multi_role_rejects_none(self, role_client, db, 
test_user): + await grant_role(test_user["id"], "viewer") + async with role_client.session_transaction() as sess: + sess["user_id"] = test_user["id"] + + response = await role_client.get("/multi-role", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + +# ════════════════════════════════════════════════════════════ +# Admin route protection +# ════════════════════════════════════════════════════════════ + +class TestAdminRouteProtection: + async def test_admin_index_requires_admin_role(self, auth_client, db): + response = await auth_client.get("/admin/", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + async def test_admin_index_accessible_with_admin_role(self, admin_client, db): + response = await admin_client.get("/admin/") + assert response.status_code == 200 + + async def test_admin_users_requires_admin_role(self, auth_client, db): + response = await auth_client.get("/admin/users", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + async def test_admin_tasks_requires_admin_role(self, auth_client, db): + response = await auth_client.get("/admin/tasks", follow_redirects=False) + assert response.status_code in (302, 303, 307) + + +# ════════════════════════════════════════════════════════════ +# Impersonation +# ════════════════════════════════════════════════════════════ + +class TestImpersonation: + async def test_impersonate_stores_admin_id(self, admin_client, db, test_user): + """Impersonating stores admin's user_id in session['admin_impersonating'].""" + # Create a second user to impersonate + now = "2025-01-01T00:00:00" + other_id = await core.execute( + "INSERT INTO users (email, name, created_at) VALUES (?, ?, ?)", + ("other@example.com", "Other", now), + ) + + async with admin_client.session_transaction() as sess: + sess["csrf_token"] = "test_csrf" + + response = await admin_client.post( + f"/admin/users/{other_id}/impersonate", + form={"csrf_token": "test_csrf"}, + 
follow_redirects=False, + ) + assert response.status_code in (302, 303, 307) + + async with admin_client.session_transaction() as sess: + assert sess["user_id"] == other_id + assert sess["admin_impersonating"] == test_user["id"] + + async def test_stop_impersonating_restores_admin(self, app, db, test_user, grant_role): + """Stopping impersonation restores the admin's user_id.""" + await grant_role(test_user["id"], "admin") + + async with app.test_client() as c: + async with c.session_transaction() as sess: + sess["user_id"] = 999 # impersonated user + sess["admin_impersonating"] = test_user["id"] + sess["csrf_token"] = "test_csrf" + + response = await c.post( + "/admin/stop-impersonating", + form={"csrf_token": "test_csrf"}, + follow_redirects=False, + ) + assert response.status_code in (302, 303, 307) + + async with c.session_transaction() as sess: + assert sess["user_id"] == test_user["id"] + assert "admin_impersonating" not in sess