Merge branch 'copier-update'

This commit is contained in:
Deeman
2026-02-19 22:46:41 +01:00
22 changed files with 1360 additions and 401 deletions

View File

@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY # Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
_commit: v0.3.0 _commit: v0.4.0
_src_path: git@gitlab.com:deemanone/materia_saas_boilerplate.master.git _src_path: git@gitlab.com:deemanone/materia_saas_boilerplate.master.git
author_email: hendrik@beanflows.coffee author_email: hendrik@beanflows.coffee
author_name: Hendrik Deeman author_name: Hendrik Deeman

73
web/CHANGELOG.md Normal file
View File

@@ -0,0 +1,73 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
## [Unreleased]
### Changed
- **Role-based access control**: `user_roles` table with `role_required()` decorator replaces password-based admin auth
- **Admin is a real user**: admins authenticate via magic links; `ADMIN_EMAILS` env var auto-grants admin role on login
- **Separated billing entities**: `billing_customers` table holds payment provider identity; `subscriptions` table holds only subscription state
- **Multiple subscriptions per user**: dropped UNIQUE constraint on `subscriptions.user_id`; `upsert_subscription` finds by `provider_subscription_id`
### Added
- Simple A/B testing with `@ab_test` decorator and optional Umami `data-tag` integration (`UMAMI_SCRIPT_URL` / `UMAMI_WEBSITE_ID` env vars)
- `user_roles` table and `grant_role()` / `revoke_role()` / `ensure_admin_role()` functions
- `billing_customers` table and `upsert_billing_customer()` / `get_billing_customer()` functions
- `role_required(*roles)` decorator in auth
- `is_admin` template context variable
- Migration `0001_roles_and_billing_customers.py` for existing databases
### Removed
- `ADMIN_PASSWORD` env var and password-based admin login
- `provider_customer_id` column from `subscriptions` table
- `admin/templates/admin/login.html`
### Changed
- **Provider-agnostic schema**: generic `provider_customer_id` / `provider_subscription_id` columns replace provider-prefixed names (`stripe_customer_id`, `paddle_customer_id`, `lemonsqueezy_customer_id`) — eliminates all Jinja conditionals from schema, SQL helpers, and route code
- **Consolidated `subscription_required` decorator**: single implementation in `auth/routes.py` supporting both plan and status checks, reads from eager-loaded `g.subscription` (zero extra queries)
- **Eager-loaded `g.subscription`**: `load_user` in `app.py` now fetches user + subscription in a single JOIN; available in all routes and templates via `g.subscription`
### Added
- `transactions` table for recording payment/refund events with idempotent `record_transaction()` helper
- Billing event hook system: `on_billing_event()` decorator and `_fire_hooks()` for domain code to react to subscription changes; errors are logged and never cause webhook 500s
### Removed
- Duplicate `subscription_required` decorator from `billing/routes.py` (consolidated in `auth/routes.py`)
- `get_user_with_subscription()` from `auth/routes.py` (replaced by eager-loaded `g.subscription`)
### Changed
- **Email SDK migration**: replaced raw httpx calls with official `resend` SDK in `core.py`
- Added `from_addr` parameter to `send_email()` for multi-address support
- Added `EMAIL_ADDRESSES` dict for named sender addresses (transactional, etc.)
- **Paddle SDK migration**: replaced raw httpx calls with official `paddle-python-sdk` in `billing/routes.py`
- Checkout, manage, cancel routes now use typed SDK methods (`PaddleClient`, `CreateTransaction`)
- Webhook verification uses SDK's `Verifier` instead of hand-rolled HMAC
- Added `PADDLE_ENVIRONMENT` config for sandbox/production toggling
- Added `_paddle_client()` helper factory
- **Dependencies**: `resend` replaces `httpx` for email; `paddle-python-sdk` replaces `httpx` for Paddle billing; `httpx` now only included for LemonSqueezy projects
- Worker `send_email` task handler now passes through `from_addr`
### Added
- `scripts/setup_paddle.py` — CLI script to create Paddle products/prices programmatically (Paddle projects only)
### Changed
- **Pico CSS → Tailwind CSS v4** — full design system migration across all templates
- Standalone Tailwind CLI binary (no Node.js) with `make css-build` / `make css-watch`
- Brand theme with component classes (`.btn`, `.card`, `.form-input`, `.table`, `.badge`, `.flash`, etc.)
- Self-hosted Commit Mono font for monospace data display
- Docker multi-stage build: CSS compiled in dedicated stage before Python build
### Removed
- Pico CSS CDN dependency
- `custom.css` (replaced by Tailwind `input.css` with `@layer components`)
- JetBrains Mono font (replaced by self-hosted Commit Mono)
### Fixed
- Admin template collision: namespaced admin templates under `admin/` subdirectory to prevent Quart's template loader from resolving auth's `login.html` or dashboard's `index.html` instead of admin's
- Admin user detail: `stripe_customer_id` hardcoded regardless of payment provider — now uses provider-aware Copier conditional (Stripe/Paddle/LemonSqueezy)
### Added
- Initial project scaffolded from quart_saas_boilerplate

View File

@@ -1,3 +1,13 @@
# CSS build stage (Tailwind standalone CLI, no Node.js)
FROM debian:bookworm-slim AS css-build
ADD https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 /usr/local/bin/tailwindcss
RUN chmod +x /usr/local/bin/tailwindcss
WORKDIR /app
COPY src/ ./src/
RUN tailwindcss -i ./src/beanflows/static/css/input.css \
-o ./src/beanflows/static/css/output.css --minify
# Build stage # Build stage
FROM python:3.12-slim AS build FROM python:3.12-slim AS build
COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/ COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/
@@ -15,6 +25,7 @@ RUN useradd -m -u 1000 appuser
WORKDIR /app WORKDIR /app
RUN mkdir -p /app/data && chown -R appuser:appuser /app RUN mkdir -p /app/data && chown -R appuser:appuser /app
COPY --from=build --chown=appuser:appuser /app . COPY --from=build --chown=appuser:appuser /app .
COPY --from=css-build /app/src/beanflows/static/css/output.css ./src/beanflows/static/css/output.css
USER appuser USER appuser
ENV PYTHONUNBUFFERED=1 ENV PYTHONUNBUFFERED=1
ENV DATABASE_PATH=/app/data/app.db ENV DATABASE_PATH=/app/data/app.db

View File

@@ -21,16 +21,16 @@ services:
command: replicate -config /etc/litestream.yml command: replicate -config /etc/litestream.yml
volumes: volumes:
- app-data:/app/data - app-data:/app/data
- ./beanflows/litestream.yml:/etc/litestream.yml:ro - ./litestream.yml:/etc/litestream.yml:ro
# ── Blue slot ───────────────────────────────────────────── # ── Blue slot ─────────────────────────────────────────────
blue-app: blue-app:
profiles: ["blue"] profiles: ["blue"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:
@@ -47,10 +47,10 @@ services:
blue-worker: blue-worker:
profiles: ["blue"] profiles: ["blue"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
command: python -m beanflows.worker command: python -m beanflows.worker
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:
@@ -61,10 +61,10 @@ services:
blue-scheduler: blue-scheduler:
profiles: ["blue"] profiles: ["blue"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
command: python -m beanflows.worker scheduler command: python -m beanflows.worker scheduler
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:
@@ -77,9 +77,9 @@ services:
green-app: green-app:
profiles: ["green"] profiles: ["green"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:
@@ -96,10 +96,10 @@ services:
green-worker: green-worker:
profiles: ["green"] profiles: ["green"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
command: python -m beanflows.worker command: python -m beanflows.worker
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:
@@ -110,10 +110,10 @@ services:
green-scheduler: green-scheduler:
profiles: ["green"] profiles: ["green"]
build: build:
context: ./beanflows context: .
restart: unless-stopped restart: unless-stopped
command: python -m beanflows.worker scheduler command: python -m beanflows.worker scheduler
env_file: ./beanflows/.env env_file: ./.env
environment: environment:
- DATABASE_PATH=/app/data/app.db - DATABASE_PATH=/app/data/app.db
volumes: volumes:

View File

@@ -12,8 +12,9 @@ dependencies = [
"aiosqlite>=0.19.0", "aiosqlite>=0.19.0",
"duckdb>=1.0.0", "duckdb>=1.0.0",
"httpx>=0.27.0", "httpx>=0.27.0",
"resend>=2.22.0",
"python-dotenv>=1.0.0", "python-dotenv>=1.0.0",
"paddle-python-sdk>=1.13.0",
"itsdangerous>=2.1.0", "itsdangerous>=2.1.0",
"jinja2>=3.1.0", "jinja2>=3.1.0",
"hypercorn>=0.17.0", "hypercorn>=0.17.0",

View File

@@ -12,24 +12,24 @@ from .core import close_db, config, get_csrf_token, init_db, setup_request_id
def create_app() -> Quart: def create_app() -> Quart:
"""Create and configure the Quart application.""" """Create and configure the Quart application."""
# Get package directory for templates # Get package directory for templates
pkg_dir = Path(__file__).parent pkg_dir = Path(__file__).parent
app = Quart( app = Quart(
__name__, __name__,
template_folder=str(pkg_dir / "templates"), template_folder=str(pkg_dir / "templates"),
static_folder=str(pkg_dir / "static"), static_folder=str(pkg_dir / "static"),
) )
app.secret_key = config.SECRET_KEY app.secret_key = config.SECRET_KEY
# Session config # Session config
app.config["SESSION_COOKIE_SECURE"] = not config.DEBUG app.config["SESSION_COOKIE_SECURE"] = not config.DEBUG
app.config["SESSION_COOKIE_HTTPONLY"] = True app.config["SESSION_COOKIE_HTTPONLY"] = True
app.config["SESSION_COOKIE_SAMESITE"] = "Lax" app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * config.SESSION_LIFETIME_DAYS app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * config.SESSION_LIFETIME_DAYS
# Database lifecycle # Database lifecycle
@app.before_serving @app.before_serving
async def startup(): async def startup():
@@ -41,7 +41,7 @@ def create_app() -> Quart:
async def shutdown(): async def shutdown():
close_analytics_db() close_analytics_db()
await close_db() await close_db()
# Security headers # Security headers
@app.after_request @app.after_request
async def add_security_headers(response): async def add_security_headers(response):
@@ -51,16 +51,42 @@ def create_app() -> Quart:
if not config.DEBUG: if not config.DEBUG:
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains" response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response return response
# Load current user before each request # Load current user + subscription + roles before each request
@app.before_request @app.before_request
async def load_user(): async def load_user():
g.user = None g.user = None
g.subscription = None
user_id = session.get("user_id") user_id = session.get("user_id")
if user_id: if user_id:
from .auth.routes import get_user_by_id from .core import fetch_one as _fetch_one
g.user = await get_user_by_id(user_id) row = await _fetch_one(
"""SELECT u.*,
bc.provider_customer_id,
(SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv,
s.id AS sub_id, s.plan, s.status AS sub_status,
s.provider_subscription_id, s.current_period_end
FROM users u
LEFT JOIN billing_customers bc ON bc.user_id = u.id
LEFT JOIN subscriptions s ON s.id = (
SELECT id FROM subscriptions
WHERE user_id = u.id
ORDER BY created_at DESC LIMIT 1
)
WHERE u.id = ? AND u.deleted_at IS NULL""",
(user_id,),
)
if row:
g.user = dict(row)
g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else []
if row["sub_id"]:
g.subscription = {
"id": row["sub_id"], "plan": row["plan"],
"status": row["sub_status"],
"provider_subscription_id": row["provider_subscription_id"],
"current_period_end": row["current_period_end"],
}
# Template context globals # Template context globals
@app.context_processor @app.context_processor
def inject_globals(): def inject_globals():
@@ -68,10 +94,14 @@ def create_app() -> Quart:
return { return {
"config": config, "config": config,
"user": g.get("user"), "user": g.get("user"),
"subscription": g.get("subscription"),
"is_admin": "admin" in (g.get("user") or {}).get("roles", []),
"now": datetime.utcnow(), "now": datetime.utcnow(),
"csrf_token": get_csrf_token, "csrf_token": get_csrf_token,
"ab_variant": getattr(g, "ab_variant", None),
"ab_tag": getattr(g, "ab_tag", None),
} }
# Health check # Health check
@app.route("/health") @app.route("/health")
async def health(): async def health():
@@ -94,7 +124,7 @@ def create_app() -> Quart:
result["duckdb"] = "not configured" result["duckdb"] = "not configured"
status_code = 200 if result["status"] == "healthy" else 500 status_code = 200 if result["status"] == "healthy" else 500
return result, status_code return result, status_code
# Register blueprints # Register blueprints
from .admin.routes import bp as admin_bp from .admin.routes import bp as admin_bp
from .api.routes import bp as api_bp from .api.routes import bp as api_bp
@@ -102,17 +132,17 @@ def create_app() -> Quart:
from .billing.routes import bp as billing_bp from .billing.routes import bp as billing_bp
from .dashboard.routes import bp as dashboard_bp from .dashboard.routes import bp as dashboard_bp
from .public.routes import bp as public_bp from .public.routes import bp as public_bp
app.register_blueprint(public_bp) app.register_blueprint(public_bp)
app.register_blueprint(auth_bp) app.register_blueprint(auth_bp)
app.register_blueprint(dashboard_bp) app.register_blueprint(dashboard_bp)
app.register_blueprint(billing_bp) app.register_blueprint(billing_bp)
app.register_blueprint(api_bp, url_prefix="/api/v1") app.register_blueprint(api_bp, url_prefix="/api/v1")
app.register_blueprint(admin_bp) app.register_blueprint(admin_bp)
# Request ID tracking # Request ID tracking
setup_request_id(app) setup_request_id(app)
return app return app

View File

@@ -71,7 +71,7 @@ async def get_valid_token(token: str) -> dict | None:
"""Get token if valid and not expired.""" """Get token if valid and not expired."""
return await fetch_one( return await fetch_one(
""" """
SELECT at.*, u.email SELECT at.*, u.email
FROM auth_tokens at FROM auth_tokens at
JOIN users u ON u.id = at.user_id JOIN users u ON u.id = at.user_id
WHERE at.token = ? AND at.expires_at > ? AND at.used_at IS NULL WHERE at.token = ? AND at.expires_at > ? AND at.used_at IS NULL
@@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None:
) )
async def get_user_with_subscription(user_id: int) -> dict | None:
"""Get user with their active subscription info."""
return await fetch_one(
"""
SELECT u.*, s.plan, s.status as sub_status, s.current_period_end
FROM users u
LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active'
WHERE u.id = ? AND u.deleted_at IS NULL
""",
(user_id,)
)
# ============================================================================= # =============================================================================
# Decorators # Decorators
# ============================================================================= # =============================================================================
@@ -116,24 +103,69 @@ def login_required(f):
return decorated return decorated
def subscription_required(plans: list[str] = None): def role_required(*roles):
"""Require active subscription, optionally of specific plan(s).""" """Require user to have at least one of the given roles."""
def decorator(f):
@wraps(f)
async def decorated(*args, **kwargs):
if not g.get("user"):
await flash("Please sign in to continue.", "warning")
return redirect(url_for("auth.login", next=request.path))
user_roles = g.user.get("roles", [])
if not any(r in user_roles for r in roles):
await flash("You don't have permission to access that page.", "error")
return redirect(url_for("dashboard.index"))
return await f(*args, **kwargs)
return decorated
return decorator
async def grant_role(user_id: int, role: str) -> None:
"""Grant a role to a user (idempotent)."""
await execute(
"INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
(user_id, role),
)
async def revoke_role(user_id: int, role: str) -> None:
"""Revoke a role from a user."""
await execute(
"DELETE FROM user_roles WHERE user_id = ? AND role = ?",
(user_id, role),
)
async def ensure_admin_role(user_id: int, email: str) -> None:
"""Grant admin role if email is in ADMIN_EMAILS."""
if email.lower() in config.ADMIN_EMAILS:
await grant_role(user_id, "admin")
def subscription_required(
plans: list[str] = None,
allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"),
):
"""Require active subscription, optionally of specific plan(s) and/or statuses.
Reads from g.subscription (eager-loaded in load_user) — zero extra queries.
"""
def decorator(f): def decorator(f):
@wraps(f) @wraps(f)
async def decorated(*args, **kwargs): async def decorated(*args, **kwargs):
if not g.get("user"): if not g.get("user"):
await flash("Please sign in to continue.", "warning") await flash("Please sign in to continue.", "warning")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
user = await get_user_with_subscription(g.user["id"]) sub = g.get("subscription")
if not user or not user.get("plan"): if not sub or sub["status"] not in allowed:
await flash("Please subscribe to access this feature.", "warning") await flash("Please subscribe to access this feature.", "warning")
return redirect(url_for("billing.pricing")) return redirect(url_for("billing.pricing"))
if plans and user["plan"] not in plans: if plans and sub["plan"] not in plans:
await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning") await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning")
return redirect(url_for("billing.pricing")) return redirect(url_for("billing.pricing"))
return await f(*args, **kwargs) return await f(*args, **kwargs)
return decorated return decorated
return decorator return decorator
@@ -149,33 +181,33 @@ async def login():
"""Login page - request magic link.""" """Login page - request magic link."""
if g.get("user"): if g.get("user"):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if request.method == "POST": if request.method == "POST":
form = await request.form form = await request.form
email = form.get("email", "").strip().lower() email = form.get("email", "").strip().lower()
if not email or "@" not in email: if not email or "@" not in email:
await flash("Please enter a valid email address.", "error") await flash("Please enter a valid email address.", "error")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
# Get or create user # Get or create user
user = await get_user_by_email(email) user = await get_user_by_email(email)
if not user: if not user:
user_id = await create_user(email) user_id = await create_user(email)
else: else:
user_id = user["id"] user_id = user["id"]
# Create magic link token # Create magic link token
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
await create_auth_token(user_id, token) await create_auth_token(user_id, token)
# Queue email # Queue email
from ..worker import enqueue from ..worker import enqueue
await enqueue("send_magic_link", {"email": email, "token": token}) await enqueue("send_magic_link", {"email": email, "token": token})
await flash("Check your email for the sign-in link!", "success") await flash("Check your email for the sign-in link!", "success")
return redirect(url_for("auth.magic_link_sent", email=email)) return redirect(url_for("auth.magic_link_sent", email=email))
return await render_template("login.html") return await render_template("login.html")
@@ -185,39 +217,39 @@ async def signup():
"""Signup page - same as login but with different messaging.""" """Signup page - same as login but with different messaging."""
if g.get("user"): if g.get("user"):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
plan = request.args.get("plan", "free") plan = request.args.get("plan", "free")
if request.method == "POST": if request.method == "POST":
form = await request.form form = await request.form
email = form.get("email", "").strip().lower() email = form.get("email", "").strip().lower()
selected_plan = form.get("plan", "free") selected_plan = form.get("plan", "free")
if not email or "@" not in email: if not email or "@" not in email:
await flash("Please enter a valid email address.", "error") await flash("Please enter a valid email address.", "error")
return redirect(url_for("auth.signup", plan=selected_plan)) return redirect(url_for("auth.signup", plan=selected_plan))
# Check if user exists # Check if user exists
user = await get_user_by_email(email) user = await get_user_by_email(email)
if user: if user:
await flash("Account already exists. Please sign in.", "info") await flash("Account already exists. Please sign in.", "info")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
# Create user # Create user
user_id = await create_user(email) user_id = await create_user(email)
# Create magic link token # Create magic link token
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
await create_auth_token(user_id, token) await create_auth_token(user_id, token)
# Queue emails # Queue emails
from ..worker import enqueue from ..worker import enqueue
await enqueue("send_magic_link", {"email": email, "token": token}) await enqueue("send_magic_link", {"email": email, "token": token})
await enqueue("send_welcome", {"email": email}) await enqueue("send_welcome", {"email": email})
await flash("Check your email to complete signup!", "success") await flash("Check your email to complete signup!", "success")
return redirect(url_for("auth.magic_link_sent", email=email)) return redirect(url_for("auth.magic_link_sent", email=email))
return await render_template("signup.html", plan=plan) return await render_template("signup.html", plan=plan)
@@ -225,29 +257,32 @@ async def signup():
async def verify(): async def verify():
"""Verify magic link token.""" """Verify magic link token."""
token = request.args.get("token") token = request.args.get("token")
if not token: if not token:
await flash("Invalid or expired link.", "error") await flash("Invalid or expired link.", "error")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
token_data = await get_valid_token(token) token_data = await get_valid_token(token)
if not token_data: if not token_data:
await flash("Invalid or expired link. Please request a new one.", "error") await flash("Invalid or expired link. Please request a new one.", "error")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
# Mark token as used # Mark token as used
await mark_token_used(token_data["id"]) await mark_token_used(token_data["id"])
# Update last login # Update last login
await update_user(token_data["user_id"], last_login_at=datetime.utcnow().isoformat()) await update_user(token_data["user_id"], last_login_at=datetime.utcnow().isoformat())
# Set session # Set session
session.permanent = True session.permanent = True
session["user_id"] = token_data["user_id"] session["user_id"] = token_data["user_id"]
# Auto-grant admin role if email is in ADMIN_EMAILS
await ensure_admin_role(token_data["user_id"], token_data["email"])
await flash("Successfully signed in!", "success") await flash("Successfully signed in!", "success")
# Redirect to intended page or dashboard # Redirect to intended page or dashboard
next_url = request.args.get("next", url_for("dashboard.index")) next_url = request.args.get("next", url_for("dashboard.index"))
return redirect(next_url) return redirect(next_url)
@@ -274,18 +309,21 @@ async def dev_login():
"""Instant login for development. Only works in DEBUG mode.""" """Instant login for development. Only works in DEBUG mode."""
if not config.DEBUG: if not config.DEBUG:
return "Not available", 404 return "Not available", 404
email = request.args.get("email", "dev@localhost") email = request.args.get("email", "dev@localhost")
user = await get_user_by_email(email) user = await get_user_by_email(email)
if not user: if not user:
user_id = await create_user(email) user_id = await create_user(email)
else: else:
user_id = user["id"] user_id = user["id"]
session.permanent = True session.permanent = True
session["user_id"] = user_id session["user_id"] = user_id
# Auto-grant admin role if email is in ADMIN_EMAILS
await ensure_admin_role(user_id, email)
await flash(f"Dev login as {email}", "success") await flash(f"Dev login as {email}", "success")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
@@ -296,19 +334,19 @@ async def resend():
"""Resend magic link.""" """Resend magic link."""
form = await request.form form = await request.form
email = form.get("email", "").strip().lower() email = form.get("email", "").strip().lower()
if not email: if not email:
await flash("Email address required.", "error") await flash("Email address required.", "error")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
user = await get_user_by_email(email) user = await get_user_by_email(email)
if user: if user:
token = secrets.token_urlsafe(32) token = secrets.token_urlsafe(32)
await create_auth_token(user["id"], token) await create_auth_token(user["id"], token)
from ..worker import enqueue from ..worker import enqueue
await enqueue("send_magic_link", {"email": email, "token": token}) await enqueue("send_magic_link", {"email": email, "token": token})
# Always show success (don't reveal if email exists) # Always show success (don't reveal if email exists)
await flash("If that email is registered, we've sent a new link.", "success") await flash("If that email is registered, we've sent a new link.", "success")
return redirect(url_for("auth.magic_link_sent", email=email)) return redirect(url_for("auth.magic_link_sent", email=email))

View File

@@ -2,16 +2,19 @@
Core infrastructure: database, config, email, and shared utilities. Core infrastructure: database, config, email, and shared utilities.
""" """
import os import os
import random
import secrets import secrets
import hashlib import hashlib
import hmac import hmac
import aiosqlite import aiosqlite
import resend
import httpx import httpx
from pathlib import Path from pathlib import Path
from functools import wraps from functools import wraps
from datetime import datetime, timedelta from datetime import datetime, timedelta
from contextvars import ContextVar from contextvars import ContextVar
from quart import request, session, g from quart import g, make_response, request, session
from dotenv import load_dotenv from dotenv import load_dotenv
# web/.env is three levels up from web/src/beanflows/core.py # web/.env is three levels up from web/src/beanflows/core.py
@@ -26,27 +29,35 @@ class Config:
SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production") SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production")
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:5001") BASE_URL: str = os.getenv("BASE_URL", "http://localhost:5001")
DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"
DATABASE_PATH: str = os.getenv("DATABASE_PATH", "data/app.db") DATABASE_PATH: str = os.getenv("DATABASE_PATH", "data/app.db")
MAGIC_LINK_EXPIRY_MINUTES: int = int(os.getenv("MAGIC_LINK_EXPIRY_MINUTES", "15")) MAGIC_LINK_EXPIRY_MINUTES: int = int(os.getenv("MAGIC_LINK_EXPIRY_MINUTES", "15"))
SESSION_LIFETIME_DAYS: int = int(os.getenv("SESSION_LIFETIME_DAYS", "30")) SESSION_LIFETIME_DAYS: int = int(os.getenv("SESSION_LIFETIME_DAYS", "30"))
PAYMENT_PROVIDER: str = "paddle" PAYMENT_PROVIDER: str = "paddle"
PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "") PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "")
PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "") PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "")
PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
PADDLE_PRICES: dict = { PADDLE_PRICES: dict = {
"starter": os.getenv("PADDLE_PRICE_STARTER", ""), "starter": os.getenv("PADDLE_PRICE_STARTER", ""),
"pro": os.getenv("PADDLE_PRICE_PRO", ""), "pro": os.getenv("PADDLE_PRICE_PRO", ""),
} }
UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "")
UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "")
RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "") RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "")
EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com") EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com")
ADMIN_EMAILS: list[str] = [
e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
]
RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100")) RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60")) RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
PLAN_FEATURES: dict = { PLAN_FEATURES: dict = {
"free": ["dashboard", "coffee_only", "limited_history"], "free": ["dashboard", "coffee_only", "limited_history"],
"starter": ["dashboard", "coffee_only", "full_history", "export", "api"], "starter": ["dashboard", "coffee_only", "full_history", "export", "api"],
@@ -74,10 +85,10 @@ async def init_db(path: str = None) -> None:
global _db global _db
db_path = path or config.DATABASE_PATH db_path = path or config.DATABASE_PATH
Path(db_path).parent.mkdir(parents=True, exist_ok=True) Path(db_path).parent.mkdir(parents=True, exist_ok=True)
_db = await aiosqlite.connect(db_path) _db = await aiosqlite.connect(db_path)
_db.row_factory = aiosqlite.Row _db.row_factory = aiosqlite.Row
await _db.execute("PRAGMA journal_mode=WAL") await _db.execute("PRAGMA journal_mode=WAL")
await _db.execute("PRAGMA foreign_keys=ON") await _db.execute("PRAGMA foreign_keys=ON")
await _db.execute("PRAGMA busy_timeout=5000") await _db.execute("PRAGMA busy_timeout=5000")
@@ -137,11 +148,11 @@ async def execute_many(sql: str, params_list: list[tuple]) -> None:
class transaction: class transaction:
"""Async context manager for transactions.""" """Async context manager for transactions."""
async def __aenter__(self): async def __aenter__(self):
self.db = await get_db() self.db = await get_db()
return self.db return self.db
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is None: if exc_type is None:
await self.db.commit() await self.db.commit()
@@ -153,25 +164,32 @@ class transaction:
# Email # Email
# ============================================================================= # =============================================================================
async def send_email(to: str, subject: str, html: str, text: str = None) -> bool: EMAIL_ADDRESSES = {
"""Send email via Resend API.""" "transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>",
}
async def send_email(
to: str, subject: str, html: str, text: str = None, from_addr: str = None
) -> bool:
"""Send email via Resend SDK."""
if not config.RESEND_API_KEY: if not config.RESEND_API_KEY:
print(f"[EMAIL] Would send to {to}: {subject}") print(f"[EMAIL] Would send to {to}: {subject}")
return True return True
async with httpx.AsyncClient() as client: resend.api_key = config.RESEND_API_KEY
response = await client.post( try:
"https://api.resend.com/emails", resend.Emails.send({
headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"}, "from": from_addr or config.EMAIL_FROM,
json={ "to": to,
"from": config.EMAIL_FROM, "subject": subject,
"to": to, "html": html,
"subject": subject, "text": text or html,
"html": html, })
"text": text or html, return True
}, except Exception as e:
) print(f"[EMAIL] Error sending to {to}: {e}")
return response.status_code == 200 return False
# ============================================================================= # =============================================================================
# CSRF Protection # CSRF Protection
@@ -214,34 +232,34 @@ async def check_rate_limit(key: str, limit: int = None, window: int = None) -> t
window = window or config.RATE_LIMIT_WINDOW window = window or config.RATE_LIMIT_WINDOW
now = datetime.utcnow() now = datetime.utcnow()
window_start = now - timedelta(seconds=window) window_start = now - timedelta(seconds=window)
# Clean old entries and count recent # Clean old entries and count recent
await execute( await execute(
"DELETE FROM rate_limits WHERE key = ? AND timestamp < ?", "DELETE FROM rate_limits WHERE key = ? AND timestamp < ?",
(key, window_start.isoformat()) (key, window_start.isoformat())
) )
result = await fetch_one( result = await fetch_one(
"SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?", "SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?",
(key, window_start.isoformat()) (key, window_start.isoformat())
) )
count = result["count"] if result else 0 count = result["count"] if result else 0
info = { info = {
"limit": limit, "limit": limit,
"remaining": max(0, limit - count - 1), "remaining": max(0, limit - count - 1),
"reset": int((window_start + timedelta(seconds=window)).timestamp()), "reset": int((window_start + timedelta(seconds=window)).timestamp()),
} }
if count >= limit: if count >= limit:
return False, info return False, info
# Record this request # Record this request
await execute( await execute(
"INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)", "INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)",
(key, now.isoformat()) (key, now.isoformat())
) )
return True, info return True, info
@@ -254,13 +272,13 @@ def rate_limit(limit: int = None, window: int = None, key_func=None):
key = key_func() key = key_func()
else: else:
key = f"ip:{request.remote_addr}" key = f"ip:{request.remote_addr}"
allowed, info = await check_rate_limit(key, limit, window) allowed, info = await check_rate_limit(key, limit, window)
if not allowed: if not allowed:
response = {"error": "Rate limit exceeded", **info} response = {"error": "Rate limit exceeded", **info}
return response, 429 return response, 429
return await f(*args, **kwargs) return await f(*args, **kwargs)
return decorated return decorated
return decorator return decorator
@@ -284,7 +302,7 @@ def setup_request_id(app):
rid = request.headers.get("X-Request-ID") or secrets.token_hex(8) rid = request.headers.get("X-Request-ID") or secrets.token_hex(8)
request_id_var.set(rid) request_id_var.set(rid)
g.request_id = rid g.request_id = rid
@app.after_request @app.after_request
async def add_request_id_header(response): async def add_request_id_header(response):
response.headers["X-Request-ID"] = get_request_id() response.headers["X-Request-ID"] = get_request_id()
@@ -294,13 +312,11 @@ def setup_request_id(app):
# Webhook Signature Verification # Webhook Signature Verification
# ============================================================================= # =============================================================================
def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool: def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
"""Verify HMAC-SHA256 webhook signature.""" """Verify HMAC-SHA256 webhook signature."""
expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
return hmac.compare_digest(signature, expected) return hmac.compare_digest(signature, expected)
# ============================================================================= # =============================================================================
# Soft Delete Helpers # Soft Delete Helpers
# ============================================================================= # =============================================================================
@@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int:
f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?", f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?",
(cutoff,) (cutoff,)
) )
# =============================================================================
# A/B Testing
# =============================================================================
def ab_test(experiment: str, variants: tuple = ("control", "treatment")):
"""Assign visitor to an A/B test variant via cookie, tag Umami pageviews."""
def decorator(f):
@wraps(f)
async def wrapper(*args, **kwargs):
cookie_key = f"ab_{experiment}"
assigned = request.cookies.get(cookie_key)
if assigned not in variants:
assigned = random.choice(variants)
g.ab_variant = assigned
g.ab_tag = f"{experiment}-{assigned}"
response = await make_response(await f(*args, **kwargs))
response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60)
return response
return wrapper
return decorator

View File

@@ -1,51 +1,95 @@
""" """
Simple migration runner. Runs schema.sql against the database. Sequential migration runner.
"""
import sqlite3 Replays all migrations in order. All databases — fresh and existing —
from pathlib import Path go through the same path. No schema.sql fast-path.
import os
import sys - Scans versions/ for NNNN_*.py files and runs unapplied ones in order
- Each migration has an up(conn) function receiving an uncommitted connection
- All pending migrations share a single transaction (batch atomicity)
"""
import importlib
import os
import re
import sqlite3
import sys
from pathlib import Path
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
VERSIONS_DIR = Path(__file__).parent / "versions"
VERSION_RE = re.compile(r"^(\d{4})_.+\.py$")
def migrate(): # Derived from the package path: …/src/<slug>/migrations/migrate.py
"""Run migrations.""" _PACKAGE = Path(__file__).parent.parent.name # e.g. "myproject"
# Get database path from env or default
db_path = os.getenv("DATABASE_PATH", "data/app.db")
def _discover_versions():
# Ensure directory exists """Return sorted list of version file stems."""
if not VERSIONS_DIR.is_dir():
return []
versions = []
for f in sorted(VERSIONS_DIR.iterdir()):
if VERSION_RE.match(f.name):
versions.append(f.stem)
return versions
def migrate(db_path=None):
if db_path is None:
db_path = os.getenv("DATABASE_PATH", "data/app.db")
Path(db_path).parent.mkdir(parents=True, exist_ok=True) Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# Read schema
schema_path = Path(__file__).parent / "schema.sql"
schema = schema_path.read_text()
# Connect and execute
conn = sqlite3.connect(db_path) conn = sqlite3.connect(db_path)
# Enable WAL mode
conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON") conn.execute("PRAGMA foreign_keys=ON")
# Run schema # Ensure tracking table exists before anything else
conn.executescript(schema) conn.execute("""
CREATE TABLE IF NOT EXISTS _migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
)
""")
conn.commit() conn.commit()
print(f"✓ Migrations complete: {db_path}") versions = _discover_versions()
applied = {
# Show tables row[0]
for row in conn.execute("SELECT name FROM _migrations").fetchall()
}
pending = [v for v in versions if v not in applied]
if pending:
for name in pending:
print(f" Applying {name}...")
mod = importlib.import_module(
f"{_PACKAGE}.migrations.versions.{name}"
)
mod.up(conn)
conn.execute(
"INSERT INTO _migrations (name) VALUES (?)", (name,)
)
conn.commit()
print(f"✓ Applied {len(pending)} migration(s): {db_path}")
else:
print(f"✓ All migrations already applied: {db_path}")
# Show tables (excluding internal sqlite/fts tables)
cursor = conn.execute( cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" "SELECT name FROM sqlite_master WHERE type='table'"
" AND name NOT LIKE 'sqlite_%'"
" ORDER BY name"
) )
tables = [row[0] for row in cursor.fetchall()] tables = [row[0] for row in cursor.fetchall()]
print(f" Tables: {', '.join(tables)}") print(f" Tables: {', '.join(tables)}")
conn.close() conn.close()

View File

@@ -1,6 +1,13 @@
-- BeanFlows Database Schema -- BeanFlows Database Schema
-- Run with: python -m beanflows.migrations.migrate -- Run with: python -m beanflows.migrations.migrate
-- Migration tracking
CREATE TABLE IF NOT EXISTS _migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Users -- Users
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens (
CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token); CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token);
CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id); CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id);
-- User Roles
CREATE TABLE IF NOT EXISTS user_roles (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role TEXT NOT NULL,
granted_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(user_id, role)
);
CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id);
CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role);
-- Billing Customers (payment provider identity, separate from subscriptions)
CREATE TABLE IF NOT EXISTS billing_customers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
provider_customer_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id);
CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id);
-- Subscriptions -- Subscriptions
CREATE TABLE IF NOT EXISTS subscriptions ( CREATE TABLE IF NOT EXISTS subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id), user_id INTEGER NOT NULL REFERENCES users(id),
plan TEXT NOT NULL DEFAULT 'free', plan TEXT NOT NULL DEFAULT 'free',
status TEXT NOT NULL DEFAULT 'free', status TEXT NOT NULL DEFAULT 'free',
provider_subscription_id TEXT,
paddle_customer_id TEXT,
paddle_subscription_id TEXT,
current_period_end TEXT, current_period_end TEXT,
created_at TEXT NOT NULL, created_at TEXT NOT NULL,
updated_at TEXT updated_at TEXT
); );
CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id); CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id);
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id); -- Transactions
CREATE TABLE IF NOT EXISTS transactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
subscription_id INTEGER REFERENCES subscriptions(id),
provider_transaction_id TEXT UNIQUE,
type TEXT NOT NULL DEFAULT 'payment',
amount_cents INTEGER,
currency TEXT DEFAULT 'USD',
status TEXT NOT NULL DEFAULT 'pending',
metadata TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id);
CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id);
-- API Keys -- API Keys
CREATE TABLE IF NOT EXISTS api_keys ( CREATE TABLE IF NOT EXISTS api_keys (
@@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks (
); );
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at); CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at);
-- Items (example domain entity - replace with your domain)
CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
name TEXT NOT NULL,
data TEXT,
created_at TEXT NOT NULL,
updated_at TEXT,
deleted_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);
-- Full-text search for items (optional)
CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
name,
data,
content='items',
content_rowid='id'
);
-- FTS triggers
CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;
CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
END;
CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;

View File

View File

@@ -0,0 +1,92 @@
"""
Create Paddle products and prices for BeanFlows.
Run once per environment (sandbox, then production).
Prints resulting price IDs as a .env snippet.
Usage:
uv run python -m beanflows.scripts.setup_paddle
"""
import os
import sys
from dotenv import load_dotenv
from paddle_billing import Client as PaddleClient
from paddle_billing import Environment, Options
from paddle_billing.Entities.Shared import CurrencyCode, Money, TaxCategory
from paddle_billing.Resources.Prices.Operations import CreatePrice
from paddle_billing.Resources.Products.Operations import CreateProduct
load_dotenv()
PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
if not PADDLE_API_KEY:
print("ERROR: Set PADDLE_API_KEY in .env first")
sys.exit(1)
PRODUCTS = [
# Subscriptions
{
"name": "Starter",
"env_key": "PADDLE_PRICE_STARTER",
"price": 900,
"currency": CurrencyCode.USD,
"interval": "month",
"type": "subscription",
},
{
"name": "Pro",
"env_key": "PADDLE_PRICE_PRO",
"price": 2900,
"currency": CurrencyCode.USD,
"interval": "month",
"type": "subscription",
},
]
def main():
env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))
print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")
env_lines = []
for spec in PRODUCTS:
# Create product
product = paddle.products.create(CreateProduct(
name=spec["name"],
tax_category=TaxCategory.Standard,
))
print(f" Product: {spec['name']} -> {product.id}")
# Create price
price_kwargs = {
"description": spec["name"],
"product_id": product.id,
"unit_price": Money(str(spec["price"]), spec["currency"]),
}
if spec["type"] == "subscription":
from paddle_billing.Entities.Shared import TimePeriod
price_kwargs["billing_cycle"] = TimePeriod(interval="month", frequency=1)
price = paddle.prices.create(CreatePrice(**price_kwargs))
print(f" Price: {spec['env_key']} = {price.id}")
env_lines.append(f"{spec['env_key']}={price.id}")
print("\n# --- .env snippet ---")
for line in env_lines:
print(line)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,90 @@
This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at:
http://scripts.sil.org/OFL
-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------
PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide
development of collaborative font projects, to support the font creation
efforts of academic and linguistic communities, and to provide a free and
open framework in which fonts may be shared and improved in partnership
with others.
The OFL allows the licensed fonts to be used, studied, modified and
redistributed freely as long as they are not sold by themselves. The
fonts, including any derivative works, can be bundled, embedded,
redistributed and/or sold with any software provided that any reserved
names are not used by derivative works. The fonts and derivatives,
however, cannot be released under any other type of license. The
requirement for fonts to remain under this license does not apply
to any document created using the fonts or their derivatives.
DEFINITIONS
"Font Software" refers to the set of files released by the Copyright
Holder(s) under this license and clearly marked as such. This may
include source files, build scripts and documentation.
"Reserved Font Name" refers to any names specified as such after the
copyright statement(s).
"Original Version" refers to the collection of Font Software components as
distributed by the Copyright Holder(s).
"Modified Version" refers to any derivative made by adding to, deleting,
or substituting -- in part or in whole -- any of the components of the
Original Version, by changing formats or by porting the Font Software to a
new environment.
"Author" refers to any designer, engineer, programmer, technical
writer or other person who contributed to the Font Software.
PERMISSION & CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining
a copy of the Font Software, to use, study, copy, merge, embed, modify,
redistribute, and sell modified and unmodified copies of the Font
Software, subject to the following conditions:
1) Neither the Font Software nor any of its individual components,
in Original or Modified Versions, may be sold by itself.
2) Original or Modified Versions of the Font Software may be bundled,
redistributed and/or sold with any software, provided that each copy
contains the above copyright notice and this license. These can be
included either as stand-alone text files, human-readable headers or
in the appropriate machine-readable metadata fields within text or
binary files as long as those fields can be easily viewed by the user.
3) No Modified Version of the Font Software may use the Reserved Font
Name(s) unless explicit written permission is granted by the corresponding
Copyright Holder. This restriction only applies to the primary font name as
presented to the users.
4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
Software shall not be used to promote, endorse or advertise any
Modified Version, except to acknowledge the contribution(s) of the
Copyright Holder(s) and the Author(s) or with their explicit written
permission.
5) The Font Software, modified or unmodified, in part or in whole,
must be distributed entirely under this license, and must not be
distributed under any other license. The requirement for fonts to
remain under this license does not apply to any document created
using the Font Software.
TERMINATION
This license becomes null and void if any of the above conditions are
not met.
DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.

View File

@@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email
HANDLERS: dict[str, callable] = {} HANDLERS: dict[str, callable] = {}
def _email_wrap(body: str) -> str:
"""Wrap email body in a branded layout with inline CSS."""
return f"""\
<!DOCTYPE html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body style="margin:0;padding:0;background-color:#F8FAFC;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background-color:#F8FAFC;padding:40px 0;">
<tr><td align="center">
<table width="480" cellpadding="0" cellspacing="0" style="background-color:#FFFFFF;border-radius:8px;border:1px solid #E2E8F0;overflow:hidden;">
<!-- Header -->
<tr><td style="background-color:#0F172A;padding:24px 32px;">
<span style="color:#FFFFFF;font-size:18px;font-weight:700;letter-spacing:-0.02em;">{config.APP_NAME}</span>
</td></tr>
<!-- Body -->
<tr><td style="padding:32px;color:#475569;font-size:15px;line-height:1.6;">
{body}
</td></tr>
<!-- Footer -->
<tr><td style="padding:20px 32px;border-top:1px solid #E2E8F0;text-align:center;">
<span style="color:#94A3B8;font-size:12px;">&copy; {config.APP_NAME} &middot; You received this because you have an account.</span>
</td></tr>
</table>
</td></tr>
</table>
</body>
</html>"""
def _email_button(url: str, label: str) -> str:
"""Render a branded CTA button for email."""
return (
f'<table cellpadding="0" cellspacing="0" style="margin:24px 0;">'
f'<tr><td style="background-color:#3B82F6;border-radius:6px;text-align:center;">'
f'<a href="{url}" style="display:inline-block;padding:12px 28px;'
f'color:#FFFFFF;font-size:15px;font-weight:600;text-decoration:none;">'
f'{label}</a></td></tr></table>'
)
def task(name: str): def task(name: str):
"""Decorator to register a task handler.""" """Decorator to register a task handler."""
def decorator(f): def decorator(f):
@@ -46,7 +86,7 @@ async def get_pending_tasks(limit: int = 10) -> list[dict]:
now = datetime.utcnow().isoformat() now = datetime.utcnow().isoformat()
return await fetch_all( return await fetch_all(
""" """
SELECT * FROM tasks SELECT * FROM tasks
WHERE status = 'pending' AND run_at <= ? WHERE status = 'pending' AND run_at <= ?
ORDER BY run_at ASC ORDER BY run_at ASC
LIMIT ? LIMIT ?
@@ -66,15 +106,15 @@ async def mark_complete(task_id: int) -> None:
async def mark_failed(task_id: int, error: str, retries: int) -> None: async def mark_failed(task_id: int, error: str, retries: int) -> None:
"""Mark task as failed, schedule retry if attempts remain.""" """Mark task as failed, schedule retry if attempts remain."""
max_retries = 3 max_retries = 3
if retries < max_retries: if retries < max_retries:
# Exponential backoff: 1min, 5min, 25min # Exponential backoff: 1min, 5min, 25min
delay = timedelta(minutes=5 ** retries) delay = timedelta(minutes=5 ** retries)
run_at = datetime.utcnow() + delay run_at = datetime.utcnow() + delay
await execute( await execute(
""" """
UPDATE tasks UPDATE tasks
SET status = 'pending', error = ?, retries = ?, run_at = ? SET status = 'pending', error = ?, retries = ?, run_at = ?
WHERE id = ? WHERE id = ?
""", """,
@@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None:
subject=payload["subject"], subject=payload["subject"],
html=payload["html"], html=payload["html"],
text=payload.get("text"), text=payload.get("text"),
from_addr=payload.get("from_addr"),
) )
@@ -106,35 +147,37 @@ async def handle_send_email(payload: dict) -> None:
async def handle_send_magic_link(payload: dict) -> None: async def handle_send_magic_link(payload: dict) -> None:
"""Send magic link email.""" """Send magic link email."""
link = f"{config.BASE_URL}/auth/verify?token={payload['token']}" link = f"{config.BASE_URL}/auth/verify?token={payload['token']}"
html = f""" body = (
<h2>Sign in to {config.APP_NAME}</h2> f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Sign in to {config.APP_NAME}</h2>'
<p>Click the link below to sign in:</p> f"<p>Click the button below to sign in. This link expires in "
<p><a href="{link}">{link}</a></p> f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>"
<p>This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p> f"{_email_button(link, 'Sign In')}"
<p>If you didn't request this, you can safely ignore this email.</p> f'<p style="font-size:13px;color:#94A3B8;">If the button doesn\'t work, copy and paste this URL into your browser:</p>'
""" f'<p style="font-size:13px;color:#94A3B8;word-break:break-all;">{link}</p>'
f'<p style="font-size:13px;color:#94A3B8;">If you didn\'t request this, you can safely ignore this email.</p>'
)
await send_email( await send_email(
to=payload["email"], to=payload["email"],
subject=f"Sign in to {config.APP_NAME}", subject=f"Sign in to {config.APP_NAME}",
html=html, html=_email_wrap(body),
) )
@task("send_welcome") @task("send_welcome")
async def handle_send_welcome(payload: dict) -> None: async def handle_send_welcome(payload: dict) -> None:
"""Send welcome email to new user.""" """Send welcome email to new user."""
html = f""" body = (
<h2>Welcome to {config.APP_NAME}!</h2> f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Welcome to {config.APP_NAME}!</h2>'
<p>Thanks for signing up. We're excited to have you.</p> f"<p>Thanks for signing up. We're excited to have you.</p>"
<p><a href="{config.BASE_URL}/dashboard">Go to your dashboard</a></p> f'{_email_button(f"{config.BASE_URL}/dashboard", "Go to Dashboard")}'
""" )
await send_email( await send_email(
to=payload["email"], to=payload["email"],
subject=f"Welcome to {config.APP_NAME}", subject=f"Welcome to {config.APP_NAME}",
html=html, html=_email_wrap(body),
) )
@@ -173,12 +216,12 @@ async def process_task(task: dict) -> None:
task_name = task["task_name"] task_name = task["task_name"]
task_id = task["id"] task_id = task["id"]
retries = task.get("retries", 0) retries = task.get("retries", 0)
handler = HANDLERS.get(task_name) handler = HANDLERS.get(task_name)
if not handler: if not handler:
await mark_failed(task_id, f"Unknown task: {task_name}", retries) await mark_failed(task_id, f"Unknown task: {task_name}", retries)
return return
try: try:
payload = json.loads(task["payload"]) if task["payload"] else {} payload = json.loads(task["payload"]) if task["payload"] else {}
await handler(payload) await handler(payload)
@@ -194,17 +237,17 @@ async def run_worker(poll_interval: float = 1.0) -> None:
"""Main worker loop.""" """Main worker loop."""
print("[WORKER] Starting...") print("[WORKER] Starting...")
await init_db() await init_db()
while True: while True:
try: try:
tasks = await get_pending_tasks(limit=10) tasks = await get_pending_tasks(limit=10)
for task in tasks: for task in tasks:
await process_task(task) await process_task(task)
if not tasks: if not tasks:
await asyncio.sleep(poll_interval) await asyncio.sleep(poll_interval)
except Exception as e: except Exception as e:
print(f"[WORKER] Error: {e}") print(f"[WORKER] Error: {e}")
await asyncio.sleep(poll_interval * 5) await asyncio.sleep(poll_interval * 5)
@@ -214,16 +257,16 @@ async def run_scheduler() -> None:
"""Schedule periodic cleanup tasks.""" """Schedule periodic cleanup tasks."""
print("[SCHEDULER] Starting...") print("[SCHEDULER] Starting...")
await init_db() await init_db()
while True: while True:
try: try:
# Schedule cleanup tasks every hour # Schedule cleanup tasks every hour
await enqueue("cleanup_expired_tokens") await enqueue("cleanup_expired_tokens")
await enqueue("cleanup_rate_limits") await enqueue("cleanup_rate_limits")
await enqueue("cleanup_old_tasks") await enqueue("cleanup_old_tasks")
await asyncio.sleep(3600) # 1 hour await asyncio.sleep(3600) # 1 hour
except Exception as e: except Exception as e:
print(f"[SCHEDULER] Error: {e}") print(f"[SCHEDULER] Error: {e}")
await asyncio.sleep(60) await asyncio.sleep(60)
@@ -231,8 +274,8 @@ async def run_scheduler() -> None:
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
if len(sys.argv) > 1 and sys.argv[1] == "scheduler": if len(sys.argv) > 1 and sys.argv[1] == "scheduler":
asyncio.run(run_scheduler()) asyncio.run(run_scheduler())
else: else:
asyncio.run(run_worker()) asyncio.run(run_worker())

View File

@@ -10,9 +10,11 @@ from unittest.mock import AsyncMock, patch
import aiosqlite import aiosqlite
import pytest import pytest
from beanflows import analytics, core
from beanflows import core
from beanflows.app import create_app from beanflows.app import create_app
SCHEMA_PATH = Path(__file__).parent.parent / "src" / "beanflows" / "migrations" / "schema.sql" SCHEMA_PATH = Path(__file__).parent.parent / "src" / "beanflows" / "migrations" / "schema.sql"
@@ -44,9 +46,7 @@ async def db():
async def app(db): async def app(db):
"""Quart app with DB already initialized (init_db/close_db patched to no-op).""" """Quart app with DB already initialized (init_db/close_db patched to no-op)."""
with patch.object(core, "init_db", new_callable=AsyncMock), \ with patch.object(core, "init_db", new_callable=AsyncMock), \
patch.object(core, "close_db", new_callable=AsyncMock), \ patch.object(core, "close_db", new_callable=AsyncMock):
patch.object(analytics, "open_analytics_db"), \
patch.object(analytics, "close_analytics_db"):
application = create_app() application = create_app()
application.config["TESTING"] = True application.config["TESTING"] = True
yield application yield application
@@ -92,22 +92,17 @@ def create_subscription(db):
user_id: int, user_id: int,
plan: str = "pro", plan: str = "pro",
status: str = "active", status: str = "active",
provider_subscription_id: str = "sub_test456",
paddle_customer_id: str = "ctm_test123",
paddle_subscription_id: str = "sub_test456",
current_period_end: str = "2025-03-01T00:00:00Z", current_period_end: str = "2025-03-01T00:00:00Z",
) -> int: ) -> int:
now = datetime.utcnow().isoformat() now = datetime.utcnow().isoformat()
async with db.execute( async with db.execute(
"""INSERT INTO subscriptions """INSERT INTO subscriptions
(user_id, plan, status, paddle_customer_id, (user_id, plan, status,
paddle_subscription_id, current_period_end, created_at, updated_at) provider_subscription_id, current_period_end, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""", VALUES (?, ?, ?, ?, ?, ?, ?)""",
(user_id, plan, status, paddle_customer_id, paddle_subscription_id, (user_id, plan, status, provider_subscription_id,
current_period_end, now, now), current_period_end, now, now),
) as cursor: ) as cursor:
sub_id = cursor.lastrowid sub_id = cursor.lastrowid
await db.commit() await db.commit()
@@ -115,6 +110,48 @@ def create_subscription(db):
return _create return _create
# ── Billing Customers ───────────────────────────────────────
@pytest.fixture
def create_billing_customer(db):
"""Factory: create a billing_customers row for a user."""
async def _create(user_id: int, provider_customer_id: str = "cust_test123") -> int:
async with db.execute(
"""INSERT INTO billing_customers (user_id, provider_customer_id)
VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET provider_customer_id = excluded.provider_customer_id""",
(user_id, provider_customer_id),
) as cursor:
row_id = cursor.lastrowid
await db.commit()
return row_id
return _create
# ── Roles ───────────────────────────────────────────────────
@pytest.fixture
def grant_role(db):
"""Factory: grant a role to a user."""
async def _grant(user_id: int, role: str) -> None:
await db.execute(
"INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
(user_id, role),
)
await db.commit()
return _grant
@pytest.fixture
async def admin_client(app, test_user, grant_role):
"""Test client with admin role and session['user_id'] pre-set."""
await grant_role(test_user["id"], "admin")
async with app.test_client() as c:
async with c.session_transaction() as sess:
sess["user_id"] = test_user["id"]
yield c
# ── Config ─────────────────────────────────────────────────── # ── Config ───────────────────────────────────────────────────
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@@ -127,6 +164,7 @@ def patch_config():
"PADDLE_API_KEY": "test_api_key_123", "PADDLE_API_KEY": "test_api_key_123",
"PADDLE_WEBHOOK_SECRET": "whsec_test_secret", "PADDLE_WEBHOOK_SECRET": "whsec_test_secret",
"PADDLE_ENVIRONMENT": "sandbox",
"PADDLE_PRICES": {"starter": "pri_starter_123", "pro": "pri_pro_456"}, "PADDLE_PRICES": {"starter": "pri_starter_123", "pro": "pri_pro_456"},
"BASE_URL": "http://localhost:5000", "BASE_URL": "http://localhost:5000",
@@ -147,6 +185,32 @@ def patch_config():
# ── Webhook helpers ────────────────────────────────────────── # ── Webhook helpers ──────────────────────────────────────────
@pytest.fixture(autouse=True)
def mock_paddle_verifier(monkeypatch):
"""Mock Paddle's webhook Verifier to accept test payloads."""
def mock_verify(self, payload, secret, signature):
if not signature or signature == "invalid_signature":
raise ValueError("Invalid signature")
monkeypatch.setattr(
"paddle_billing.Notifications.Verifier.verify",
mock_verify,
)
@pytest.fixture
def mock_paddle_client(monkeypatch):
"""Mock _paddle_client() to return a fake PaddleClient."""
from unittest.mock import MagicMock
mock_client = MagicMock()
monkeypatch.setattr(
"beanflows.billing.routes._paddle_client",
lambda: mock_client,
)
return mock_client
def make_webhook_payload( def make_webhook_payload(
event_type: str, event_type: str,
subscription_id: str = "sub_test456", subscription_id: str = "sub_test456",
@@ -172,76 +236,8 @@ def make_webhook_payload(
} }
def sign_payload(payload_bytes: bytes, secret: str = "whsec_test_secret") -> str: def sign_payload(payload_bytes: bytes) -> str:
"""Compute HMAC-SHA256 signature for a webhook payload.""" """Return a dummy signature for Paddle webhook tests (Verifier is mocked)."""
return hmac.new(secret.encode(), payload_bytes, hashlib.sha256).hexdigest() return "ts=1234567890;h1=dummy_signature"
# ── Analytics mock data ──────────────────────────────────────
MOCK_TIME_SERIES = [
{"market_year": 2018, "Production": 165000, "Exports": 115000, "Imports": 105000,
"Ending_Stocks": 33000, "Total_Distribution": 160000},
{"market_year": 2019, "Production": 168000, "Exports": 118000, "Imports": 108000,
"Ending_Stocks": 34000, "Total_Distribution": 163000},
{"market_year": 2020, "Production": 170000, "Exports": 120000, "Imports": 110000,
"Ending_Stocks": 35000, "Total_Distribution": 165000},
{"market_year": 2021, "Production": 175000, "Exports": 125000, "Imports": 115000,
"Ending_Stocks": 36000, "Total_Distribution": 170000},
{"market_year": 2022, "Production": 172000, "Exports": 122000, "Imports": 112000,
"Ending_Stocks": 34000, "Total_Distribution": 168000},
]
MOCK_TOP_COUNTRIES = [
{"country_name": "Brazil", "country_code": "BR", "market_year": 2022, "Production": 65000},
{"country_name": "Vietnam", "country_code": "VN", "market_year": 2022, "Production": 30000},
{"country_name": "Colombia", "country_code": "CO", "market_year": 2022, "Production": 14000},
]
MOCK_STU_TREND = [
{"market_year": 2020, "Stock_to_Use_Ratio_pct": 21.2},
{"market_year": 2021, "Stock_to_Use_Ratio_pct": 21.1},
{"market_year": 2022, "Stock_to_Use_Ratio_pct": 20.2},
]
MOCK_BALANCE = [
{"market_year": 2020, "Production": 170000, "Total_Distribution": 165000, "Supply_Demand_Balance": 5000},
{"market_year": 2021, "Production": 175000, "Total_Distribution": 170000, "Supply_Demand_Balance": 5000},
{"market_year": 2022, "Production": 172000, "Total_Distribution": 168000, "Supply_Demand_Balance": 4000},
]
MOCK_YOY = [
{"country_name": "Brazil", "country_code": "BR", "market_year": 2022,
"Production": 65000, "Production_YoY_pct": -3.5},
{"country_name": "Vietnam", "country_code": "VN", "market_year": 2022,
"Production": 30000, "Production_YoY_pct": 2.1},
]
MOCK_COMMODITIES = [
{"commodity_code": 711100, "commodity_name": "Coffee, Green"},
{"commodity_code": 222000, "commodity_name": "Soybeans"},
]
@pytest.fixture
def mock_analytics():
"""Patch all analytics query functions with mock data."""
with patch.object(analytics, "get_global_time_series", new_callable=AsyncMock,
return_value=MOCK_TIME_SERIES), \
patch.object(analytics, "get_top_countries", new_callable=AsyncMock,
return_value=MOCK_TOP_COUNTRIES), \
patch.object(analytics, "get_stock_to_use_trend", new_callable=AsyncMock,
return_value=MOCK_STU_TREND), \
patch.object(analytics, "get_supply_demand_balance", new_callable=AsyncMock,
return_value=MOCK_BALANCE), \
patch.object(analytics, "get_production_yoy_by_country", new_callable=AsyncMock,
return_value=MOCK_YOY), \
patch.object(analytics, "get_country_comparison", new_callable=AsyncMock,
return_value=[]), \
patch.object(analytics, "get_available_commodities", new_callable=AsyncMock,
return_value=MOCK_COMMODITIES), \
patch.object(analytics, "fetch_analytics", new_callable=AsyncMock,
return_value=[{"result": 1}]):
yield

View File

@@ -9,10 +9,13 @@ from hypothesis import strategies as st
from beanflows.billing.routes import ( from beanflows.billing.routes import (
can_access_feature, can_access_feature,
get_billing_customer,
get_subscription, get_subscription,
get_subscription_by_provider_id, get_subscription_by_provider_id,
is_within_limits, is_within_limits,
record_transaction,
update_subscription_status, update_subscription_status,
upsert_billing_customer,
upsert_subscription, upsert_subscription,
) )
from beanflows.core import config from beanflows.core import config
@@ -45,7 +48,6 @@ class TestUpsertSubscription:
user_id=test_user["id"], user_id=test_user["id"],
plan="pro", plan="pro",
status="active", status="active",
provider_customer_id="cust_abc",
provider_subscription_id="sub_xyz", provider_subscription_id="sub_xyz",
current_period_end="2025-06-01T00:00:00Z", current_period_end="2025-06-01T00:00:00Z",
) )
@@ -53,39 +55,53 @@ class TestUpsertSubscription:
row = await get_subscription(test_user["id"]) row = await get_subscription(test_user["id"])
assert row["plan"] == "pro" assert row["plan"] == "pro"
assert row["status"] == "active" assert row["status"] == "active"
assert row["provider_subscription_id"] == "sub_xyz"
assert row["paddle_customer_id"] == "cust_abc"
assert row["paddle_subscription_id"] == "sub_xyz"
assert row["current_period_end"] == "2025-06-01T00:00:00Z" assert row["current_period_end"] == "2025-06-01T00:00:00Z"
async def test_update_existing_subscription(self, db, test_user, create_subscription): async def test_update_existing_by_provider_subscription_id(self, db, test_user):
original_id = await create_subscription( """upsert finds existing by provider_subscription_id, not user_id."""
test_user["id"], plan="starter", status="active", await upsert_subscription(
user_id=test_user["id"],
paddle_subscription_id="sub_old", plan="starter",
status="active",
provider_subscription_id="sub_same",
) )
returned_id = await upsert_subscription( returned_id = await upsert_subscription(
user_id=test_user["id"], user_id=test_user["id"],
plan="pro", plan="pro",
status="active", status="active",
provider_customer_id="cust_new", provider_subscription_id="sub_same",
provider_subscription_id="sub_new",
) )
assert returned_id == original_id
row = await get_subscription(test_user["id"]) row = await get_subscription(test_user["id"])
assert row["plan"] == "pro" assert row["plan"] == "pro"
assert row["provider_subscription_id"] == "sub_same"
assert row["paddle_subscription_id"] == "sub_new" async def test_different_provider_id_creates_new(self, db, test_user):
"""Different provider_subscription_id creates a new row (multi-sub support)."""
await upsert_subscription(
user_id=test_user["id"],
plan="starter",
status="active",
provider_subscription_id="sub_first",
)
await upsert_subscription(
user_id=test_user["id"],
plan="pro",
status="active",
provider_subscription_id="sub_second",
)
from beanflows.core import fetch_all
rows = await fetch_all(
"SELECT * FROM subscriptions WHERE user_id = ? ORDER BY created_at",
(test_user["id"],),
)
assert len(rows) == 2
async def test_upsert_with_none_period_end(self, db, test_user): async def test_upsert_with_none_period_end(self, db, test_user):
await upsert_subscription( await upsert_subscription(
user_id=test_user["id"], user_id=test_user["id"],
plan="pro", plan="pro",
status="active", status="active",
provider_customer_id="cust_1",
provider_subscription_id="sub_1", provider_subscription_id="sub_1",
current_period_end=None, current_period_end=None,
) )
@@ -93,6 +109,28 @@ class TestUpsertSubscription:
assert row["current_period_end"] is None assert row["current_period_end"] is None
# ════════════════════════════════════════════════════════════
# upsert_billing_customer / get_billing_customer
# ════════════════════════════════════════════════════════════
class TestUpsertBillingCustomer:
async def test_creates_billing_customer(self, db, test_user):
await upsert_billing_customer(test_user["id"], "cust_abc")
row = await get_billing_customer(test_user["id"])
assert row is not None
assert row["provider_customer_id"] == "cust_abc"
async def test_updates_existing_customer(self, db, test_user):
await upsert_billing_customer(test_user["id"], "cust_old")
await upsert_billing_customer(test_user["id"], "cust_new")
row = await get_billing_customer(test_user["id"])
assert row["provider_customer_id"] == "cust_new"
async def test_get_returns_none_for_unknown_user(self, db):
row = await get_billing_customer(99999)
assert row is None
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
# get_subscription_by_provider_id # get_subscription_by_provider_id
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
@@ -102,10 +140,8 @@ class TestGetSubscriptionByProviderId:
result = await get_subscription_by_provider_id("nonexistent") result = await get_subscription_by_provider_id("nonexistent")
assert result is None assert result is None
async def test_finds_by_provider_subscription_id(self, db, test_user, create_subscription):
async def test_finds_by_paddle_subscription_id(self, db, test_user, create_subscription): await create_subscription(test_user["id"], provider_subscription_id="sub_findme")
await create_subscription(test_user["id"], paddle_subscription_id="sub_findme")
result = await get_subscription_by_provider_id("sub_findme") result = await get_subscription_by_provider_id("sub_findme")
assert result is not None assert result is not None
assert result["user_id"] == test_user["id"] assert result["user_id"] == test_user["id"]
@@ -117,18 +153,14 @@ class TestGetSubscriptionByProviderId:
class TestUpdateSubscriptionStatus: class TestUpdateSubscriptionStatus:
async def test_updates_status(self, db, test_user, create_subscription): async def test_updates_status(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_upd")
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_upd")
await update_subscription_status("sub_upd", status="cancelled") await update_subscription_status("sub_upd", status="cancelled")
row = await get_subscription(test_user["id"]) row = await get_subscription(test_user["id"])
assert row["status"] == "cancelled" assert row["status"] == "cancelled"
assert row["updated_at"] is not None assert row["updated_at"] is not None
async def test_updates_extra_fields(self, db, test_user, create_subscription): async def test_updates_extra_fields(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], provider_subscription_id="sub_extra")
await create_subscription(test_user["id"], paddle_subscription_id="sub_extra")
await update_subscription_status( await update_subscription_status(
"sub_extra", "sub_extra",
status="active", status="active",
@@ -141,9 +173,7 @@ class TestUpdateSubscriptionStatus:
assert row["current_period_end"] == "2026-01-01T00:00:00Z" assert row["current_period_end"] == "2026-01-01T00:00:00Z"
async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription): async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], provider_subscription_id="sub_known", status="active")
await create_subscription(test_user["id"], paddle_subscription_id="sub_known", status="active")
await update_subscription_status("sub_unknown", status="expired") await update_subscription_status("sub_unknown", status="expired")
row = await get_subscription(test_user["id"]) row = await get_subscription(test_user["id"])
assert row["status"] == "active" # unchanged assert row["status"] == "active" # unchanged
@@ -155,22 +185,22 @@ class TestUpdateSubscriptionStatus:
class TestCanAccessFeature: class TestCanAccessFeature:
async def test_no_subscription_gets_free_features(self, db, test_user): async def test_no_subscription_gets_free_features(self, db, test_user):
assert await can_access_feature(test_user["id"], "dashboard") is True assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is False assert await can_access_feature(test_user["id"], "export") is False
assert await can_access_feature(test_user["id"], "api") is False assert await can_access_feature(test_user["id"], "api") is False
async def test_active_pro_gets_all_features(self, db, test_user, create_subscription): async def test_active_pro_gets_all_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="active") await create_subscription(test_user["id"], plan="pro", status="active")
assert await can_access_feature(test_user["id"], "dashboard") is True assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is True assert await can_access_feature(test_user["id"], "export") is True
assert await can_access_feature(test_user["id"], "api") is True assert await can_access_feature(test_user["id"], "api") is True
assert await can_access_feature(test_user["id"], "priority_support") is True assert await can_access_feature(test_user["id"], "priority_support") is True
async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription): async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="starter", status="active") await create_subscription(test_user["id"], plan="starter", status="active")
assert await can_access_feature(test_user["id"], "dashboard") is True assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is True assert await can_access_feature(test_user["id"], "export") is True
assert await can_access_feature(test_user["id"], "all_commodities") is False assert await can_access_feature(test_user["id"], "api") is False
async def test_cancelled_still_has_features(self, db, test_user, create_subscription): async def test_cancelled_still_has_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="cancelled") await create_subscription(test_user["id"], plan="pro", status="cancelled")
@@ -183,7 +213,7 @@ class TestCanAccessFeature:
async def test_expired_falls_back_to_free(self, db, test_user, create_subscription): async def test_expired_falls_back_to_free(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="expired") await create_subscription(test_user["id"], plan="pro", status="expired")
assert await can_access_feature(test_user["id"], "api") is False assert await can_access_feature(test_user["id"], "api") is False
assert await can_access_feature(test_user["id"], "dashboard") is True assert await can_access_feature(test_user["id"], "basic") is True
async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription): async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="past_due") await create_subscription(test_user["id"], plan="pro", status="past_due")
@@ -203,30 +233,28 @@ class TestCanAccessFeature:
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
class TestIsWithinLimits: class TestIsWithinLimits:
async def test_free_user_no_api_calls(self, db, test_user): async def test_free_user_within_limits(self, db, test_user):
assert await is_within_limits(test_user["id"], "api_calls", 0) is False assert await is_within_limits(test_user["id"], "items", 50) is True
async def test_free_user_commodity_limit(self, db, test_user): async def test_free_user_at_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "commodities", 0) is True assert await is_within_limits(test_user["id"], "items", 100) is False
assert await is_within_limits(test_user["id"], "commodities", 1) is False
async def test_free_user_history_limit(self, db, test_user): async def test_free_user_over_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "history_years", 4) is True assert await is_within_limits(test_user["id"], "items", 150) is False
assert await is_within_limits(test_user["id"], "history_years", 5) is False
async def test_pro_unlimited(self, db, test_user, create_subscription): async def test_pro_unlimited(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="active") await create_subscription(test_user["id"], plan="pro", status="active")
assert await is_within_limits(test_user["id"], "commodities", 999999) is True assert await is_within_limits(test_user["id"], "items", 999999) is True
assert await is_within_limits(test_user["id"], "api_calls", 999999) is True assert await is_within_limits(test_user["id"], "api_calls", 999999) is True
async def test_starter_limits(self, db, test_user, create_subscription): async def test_starter_limits(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="starter", status="active") await create_subscription(test_user["id"], plan="starter", status="active")
assert await is_within_limits(test_user["id"], "api_calls", 9999) is True assert await is_within_limits(test_user["id"], "items", 999) is True
assert await is_within_limits(test_user["id"], "api_calls", 10000) is False assert await is_within_limits(test_user["id"], "items", 1000) is False
async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription): async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="expired") await create_subscription(test_user["id"], plan="pro", status="expired")
assert await is_within_limits(test_user["id"], "api_calls", 0) is False assert await is_within_limits(test_user["id"], "items", 100) is False
async def test_unknown_resource_returns_false(self, db, test_user): async def test_unknown_resource_returns_false(self, db, test_user):
assert await is_within_limits(test_user["id"], "unicorns", 0) is False assert await is_within_limits(test_user["id"], "unicorns", 0) is False
@@ -238,7 +266,7 @@ class TestIsWithinLimits:
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"] STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"]
FEATURES = ["dashboard", "export", "api", "priority_support"] FEATURES = ["basic", "export", "api", "priority_support"]
ACTIVE_STATUSES = {"active", "on_trial", "cancelled"} ACTIVE_STATUSES = {"active", "on_trial", "cancelled"}
@@ -282,9 +310,9 @@ async def test_plan_feature_matrix(db, test_user, create_subscription, plan, fea
@pytest.mark.parametrize("plan", PLANS) @pytest.mark.parametrize("plan", PLANS)
@pytest.mark.parametrize("resource,at_limit", [ @pytest.mark.parametrize("resource,at_limit", [
("commodities", 1), ("items", 100),
("commodities", 65), ("items", 1000),
("api_calls", 0), ("api_calls", 1000),
("api_calls", 10000), ("api_calls", 10000),
]) ])
async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit): async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit):
@@ -307,11 +335,11 @@ async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resou
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
class TestLimitsHypothesis: class TestLimitsHypothesis:
@given(count=st.integers(min_value=0, max_value=100)) @given(count=st.integers(min_value=0, max_value=10000))
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture]) @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_free_limit_boundary_commodities(self, db, test_user, count): async def test_free_limit_boundary_items(self, db, test_user, count):
result = await is_within_limits(test_user["id"], "commodities", count) result = await is_within_limits(test_user["id"], "items", count)
assert result == (count < 1) assert result == (count < 100)
@given(count=st.integers(min_value=0, max_value=100000)) @given(count=st.integers(min_value=0, max_value=100000))
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture]) @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
@@ -319,7 +347,56 @@ class TestLimitsHypothesis:
# Use upsert to avoid duplicate inserts across Hypothesis examples # Use upsert to avoid duplicate inserts across Hypothesis examples
await upsert_subscription( await upsert_subscription(
user_id=test_user["id"], plan="pro", status="active", user_id=test_user["id"], plan="pro", status="active",
provider_customer_id="cust_hyp", provider_subscription_id="sub_hyp", provider_subscription_id="sub_hyp",
) )
result = await is_within_limits(test_user["id"], "commodities", count) result = await is_within_limits(test_user["id"], "items", count)
assert result is True assert result is True
# ════════════════════════════════════════════════════════════
# record_transaction
# ════════════════════════════════════════════════════════════
class TestRecordTransaction:
async def test_inserts_transaction(self, db, test_user):
txn_id = await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_abc123",
type="payment",
amount_cents=2999,
currency="EUR",
status="completed",
)
assert txn_id is not None and txn_id > 0
from beanflows.core import fetch_one
row = await fetch_one(
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
("txn_abc123",),
)
assert row is not None
assert row["user_id"] == test_user["id"]
assert row["amount_cents"] == 2999
assert row["currency"] == "EUR"
assert row["status"] == "completed"
async def test_idempotent_on_duplicate_provider_id(self, db, test_user):
await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_dup",
amount_cents=1000,
)
# Second insert with same provider_transaction_id should be ignored
await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_dup",
amount_cents=9999,
)
from beanflows.core import fetch_all
rows = await fetch_all(
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
("txn_dup",),
)
assert len(rows) == 1
assert rows[0]["amount_cents"] == 1000 # original value preserved

View File

@@ -0,0 +1,122 @@
"""
Tests for the billing event hook system.
"""
import pytest
from beanflows.billing.routes import _billing_hooks, _fire_hooks, on_billing_event
@pytest.fixture(autouse=True)
def clear_hooks():
"""Ensure hooks are clean before and after each test."""
_billing_hooks.clear()
yield
_billing_hooks.clear()
# ════════════════════════════════════════════════════════════
# Registration
# ════════════════════════════════════════════════════════════
class TestOnBillingEvent:
def test_registers_single_event(self):
@on_billing_event("subscription.activated")
async def my_hook(event_type, data):
pass
assert "subscription.activated" in _billing_hooks
assert my_hook in _billing_hooks["subscription.activated"]
def test_registers_multiple_events(self):
@on_billing_event("subscription.activated", "subscription.updated")
async def my_hook(event_type, data):
pass
assert my_hook in _billing_hooks["subscription.activated"]
assert my_hook in _billing_hooks["subscription.updated"]
def test_multiple_hooks_per_event(self):
@on_billing_event("subscription.activated")
async def hook_a(event_type, data):
pass
@on_billing_event("subscription.activated")
async def hook_b(event_type, data):
pass
assert len(_billing_hooks["subscription.activated"]) == 2
def test_decorator_returns_original_function(self):
@on_billing_event("test_event")
async def my_hook(event_type, data):
pass
assert my_hook.__name__ == "my_hook"
# ════════════════════════════════════════════════════════════
# Firing
# ════════════════════════════════════════════════════════════
class TestFireHooks:
async def test_fires_registered_hook(self):
calls = []
@on_billing_event("subscription.activated")
async def recorder(event_type, data):
calls.append((event_type, data))
await _fire_hooks("subscription.activated", {"id": "sub_123"})
assert len(calls) == 1
assert calls[0] == ("subscription.activated", {"id": "sub_123"})
async def test_no_hooks_registered_is_noop(self):
# Should not raise
await _fire_hooks("unregistered_event", {"id": "sub_123"})
async def test_fires_all_hooks_for_event(self):
calls = []
@on_billing_event("subscription.activated")
async def hook_a(event_type, data):
calls.append("a")
@on_billing_event("subscription.activated")
async def hook_b(event_type, data):
calls.append("b")
await _fire_hooks("subscription.activated", {})
assert calls == ["a", "b"]
# ════════════════════════════════════════════════════════════
# Error isolation
# ════════════════════════════════════════════════════════════
class TestHookErrorIsolation:
async def test_failing_hook_does_not_block_others(self):
calls = []
@on_billing_event("subscription.activated")
async def failing_hook(event_type, data):
raise RuntimeError("boom")
@on_billing_event("subscription.activated")
async def good_hook(event_type, data):
calls.append("ok")
# Should not raise despite first hook failing
await _fire_hooks("subscription.activated", {})
assert calls == ["ok"]
async def test_failing_hook_is_logged(self, caplog):
@on_billing_event("subscription.activated")
async def bad_hook(event_type, data):
raise ValueError("test error")
import logging
with caplog.at_level(logging.ERROR):
await _fire_hooks("subscription.activated", {})
assert "bad_hook" in caplog.text
assert "test error" in caplog.text

View File

@@ -1,12 +1,15 @@
""" """
Route integration tests for Paddle billing endpoints. Route integration tests for Paddle billing endpoints.
External Paddle API calls mocked with respx.
"""
import json
import httpx Paddle SDK calls mocked via mock_paddle_client fixture.
"""
from unittest.mock import MagicMock
import pytest import pytest
import respx
CHECKOUT_METHOD = "POST" CHECKOUT_METHOD = "POST"
@@ -54,24 +57,16 @@ class TestCheckoutRoute:
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
@respx.mock
async def test_creates_checkout_session(self, auth_client, db, test_user):
respx.post("https://api.paddle.com/transactions").mock(
return_value=httpx.Response(200, json={
"data": {
"checkout": {
"url": "https://checkout.paddle.com/test_123"
}
}
})
)
async def test_creates_checkout_session(self, auth_client, db, test_user, mock_paddle_client):
mock_txn = MagicMock()
mock_txn.checkout.url = "https://checkout.paddle.com/test_123"
mock_paddle_client.transactions.create.return_value = mock_txn
response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False) response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False)
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
mock_paddle_client.transactions.create.assert_called_once()
async def test_invalid_plan_rejected(self, auth_client, db, test_user): async def test_invalid_plan_rejected(self, auth_client, db, test_user):
@@ -82,20 +77,13 @@ class TestCheckoutRoute:
@respx.mock async def test_api_error_propagates(self, auth_client, db, test_user, mock_paddle_client):
async def test_api_error_propagates(self, auth_client, db, test_user): mock_paddle_client.transactions.create.side_effect = Exception("API error")
with pytest.raises(Exception, match="API error"):
respx.post("https://api.paddle.com/transactions").mock(
return_value=httpx.Response(500, json={"error": "server error"})
)
with pytest.raises(httpx.HTTPStatusError):
await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}") await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}")
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
# Manage subscription / Portal # Manage subscription / Portal
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
@@ -110,24 +98,18 @@ class TestManageRoute:
response = await auth_client.post("/billing/manage", follow_redirects=False) response = await auth_client.post("/billing/manage", follow_redirects=False)
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
@respx.mock
async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_test") async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
await create_subscription(test_user["id"], provider_subscription_id="sub_test")
respx.get("https://api.paddle.com/subscriptions/sub_test").mock(
return_value=httpx.Response(200, json={
"data": {
"management_urls": {
"update_payment_method": "https://paddle.com/manage/test_123"
}
}
})
)
mock_sub = MagicMock()
mock_sub.management_urls.update_payment_method = "https://paddle.com/manage/test_123"
mock_paddle_client.subscriptions.get.return_value = mock_sub
response = await auth_client.post("/billing/manage", follow_redirects=False) response = await auth_client.post("/billing/manage", follow_redirects=False)
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
mock_paddle_client.subscriptions.get.assert_called_once_with("sub_test")
@@ -145,18 +127,14 @@ class TestCancelRoute:
response = await auth_client.post("/billing/cancel", follow_redirects=False) response = await auth_client.post("/billing/cancel", follow_redirects=False)
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
@respx.mock
async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
respx.post("https://api.paddle.com/subscriptions/sub_test/cancel").mock(
return_value=httpx.Response(200, json={"data": {}})
)
async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
await create_subscription(test_user["id"], provider_subscription_id="sub_test")
response = await auth_client.post("/billing/cancel", follow_redirects=False) response = await auth_client.post("/billing/cancel", follow_redirects=False)
assert response.status_code in (302, 303, 307) assert response.status_code in (302, 303, 307)
mock_paddle_client.subscriptions.cancel.assert_called_once()
@@ -167,8 +145,9 @@ class TestCancelRoute:
# subscription_required decorator # subscription_required decorator
# ════════════════════════════════════════════════════════════ # ════════════════════════════════════════════════════════════
from beanflows.billing.routes import subscription_required from quart import Blueprint # noqa: E402
from quart import Blueprint
from beanflows.auth.routes import subscription_required # noqa: E402
test_bp = Blueprint("test", __name__) test_bp = Blueprint("test", __name__)

View File

@@ -5,13 +5,13 @@ Covers signature verification, event parsing, subscription lifecycle transitions
import json import json
import pytest import pytest
from conftest import make_webhook_payload, sign_payload
from hypothesis import HealthCheck, given from hypothesis import HealthCheck, given
from hypothesis import settings as h_settings from hypothesis import settings as h_settings
from hypothesis import strategies as st from hypothesis import strategies as st
from beanflows.billing.routes import get_subscription from beanflows.billing.routes import get_billing_customer, get_subscription
from conftest import make_webhook_payload, sign_payload
WEBHOOK_PATH = "/billing/webhook/paddle" WEBHOOK_PATH = "/billing/webhook/paddle"
@@ -72,18 +72,19 @@ class TestWebhookSignature:
async def test_modified_payload_rejected(self, client, db, test_user): async def test_modified_payload_rejected(self, client, db, test_user):
# Paddle SDK Verifier handles tamper detection internally.
# We test signature rejection via test_invalid_signature_rejected above.
# This test verifies the Verifier is actually called by sending
# a payload with an explicitly bad signature.
payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"])) payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"]))
payload_bytes = json.dumps(payload).encode() payload_bytes = json.dumps(payload).encode()
sig = sign_payload(payload_bytes)
tampered = payload_bytes + b"extra"
# Paddle/LemonSqueezy: HMAC signature verification fails before JSON parsing
response = await client.post( response = await client.post(
WEBHOOK_PATH, WEBHOOK_PATH,
data=tampered, data=payload_bytes,
headers={SIG_HEADER: sig, "Content-Type": "application/json"}, headers={SIG_HEADER: "invalid_signature", "Content-Type": "application/json"},
) )
assert response.status_code in (400, 401) assert response.status_code == 400
async def test_empty_payload_rejected(self, client, db): async def test_empty_payload_rejected(self, client, db):
@@ -105,7 +106,7 @@ class TestWebhookSignature:
class TestWebhookSubscriptionActivated: class TestWebhookSubscriptionActivated:
async def test_creates_subscription(self, client, db, test_user): async def test_creates_subscription_and_billing_customer(self, client, db, test_user):
payload = make_webhook_payload( payload = make_webhook_payload(
"subscription.activated", "subscription.activated",
user_id=str(test_user["id"]), user_id=str(test_user["id"]),
@@ -126,10 +127,14 @@ class TestWebhookSubscriptionActivated:
assert sub["plan"] == "starter" assert sub["plan"] == "starter"
assert sub["status"] == "active" assert sub["status"] == "active"
bc = await get_billing_customer(test_user["id"])
assert bc is not None
assert bc["provider_customer_id"] == "ctm_test123"
class TestWebhookSubscriptionUpdated: class TestWebhookSubscriptionUpdated:
async def test_updates_subscription_status(self, client, db, test_user, create_subscription): async def test_updates_subscription_status(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload( payload = make_webhook_payload(
"subscription.updated", "subscription.updated",
@@ -152,7 +157,7 @@ class TestWebhookSubscriptionUpdated:
class TestWebhookSubscriptionCanceled: class TestWebhookSubscriptionCanceled:
async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription): async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload( payload = make_webhook_payload(
"subscription.canceled", "subscription.canceled",
@@ -174,7 +179,7 @@ class TestWebhookSubscriptionCanceled:
class TestWebhookSubscriptionPastDue: class TestWebhookSubscriptionPastDue:
async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription): async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456") await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload( payload = make_webhook_payload(
"subscription.past_due", "subscription.past_due",
@@ -209,7 +214,7 @@ class TestWebhookSubscriptionPastDue:
]) ])
async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status): async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status):
if event_type != "subscription.activated": if event_type != "subscription.activated":
await create_subscription(test_user["id"], paddle_subscription_id="sub_test456") await create_subscription(test_user["id"], provider_subscription_id="sub_test456")
payload = make_webhook_payload(event_type, user_id=str(test_user["id"])) payload = make_webhook_payload(event_type, user_id=str(test_user["id"]))
payload_bytes = json.dumps(payload).encode() payload_bytes = json.dumps(payload).encode()

242
web/tests/test_roles.py Normal file
View File

@@ -0,0 +1,242 @@
"""
Tests for role-based access control: role_required decorator, grant/revoke/ensure_admin_role,
and admin route protection.
"""
import pytest
from quart import Blueprint
from beanflows.auth.routes import (
ensure_admin_role,
grant_role,
revoke_role,
role_required,
)
from beanflows import core
# ════════════════════════════════════════════════════════════
# grant_role / revoke_role
# ════════════════════════════════════════════════════════════
class TestGrantRole:
async def test_grants_role(self, db, test_user):
await grant_role(test_user["id"], "admin")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ?",
(test_user["id"],),
)
assert row is not None
assert row["role"] == "admin"
async def test_idempotent(self, db, test_user):
await grant_role(test_user["id"], "admin")
await grant_role(test_user["id"], "admin")
rows = await core.fetch_all(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert len(rows) == 1
class TestRevokeRole:
async def test_revokes_existing_role(self, db, test_user):
await grant_role(test_user["id"], "admin")
await revoke_role(test_user["id"], "admin")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
async def test_noop_for_missing_role(self, db, test_user):
# Should not raise
await revoke_role(test_user["id"], "nonexistent")
# ════════════════════════════════════════════════════════════
# ensure_admin_role
# ════════════════════════════════════════════════════════════
class TestEnsureAdminRole:
async def test_grants_admin_for_listed_email(self, db, test_user):
core.config.ADMIN_EMAILS = ["test@example.com"]
try:
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is not None
finally:
core.config.ADMIN_EMAILS = []
async def test_skips_for_unlisted_email(self, db, test_user):
core.config.ADMIN_EMAILS = ["boss@example.com"]
try:
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
finally:
core.config.ADMIN_EMAILS = []
async def test_empty_admin_emails_grants_nothing(self, db, test_user):
core.config.ADMIN_EMAILS = []
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
async def test_case_insensitive_matching(self, db, test_user):
core.config.ADMIN_EMAILS = ["test@example.com"]
try:
await ensure_admin_role(test_user["id"], "Test@Example.COM")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is not None
finally:
core.config.ADMIN_EMAILS = []
# ════════════════════════════════════════════════════════════
# role_required decorator
# ════════════════════════════════════════════════════════════
role_test_bp = Blueprint("role_test", __name__)
@role_test_bp.route("/admin-only")
@role_required("admin")
async def admin_only_route():
return "admin-ok", 200
@role_test_bp.route("/multi-role")
@role_required("admin", "editor")
async def multi_role_route():
return "multi-ok", 200
class TestRoleRequired:
@pytest.fixture
async def role_app(self, app):
app.register_blueprint(role_test_bp)
return app
@pytest.fixture
async def role_client(self, role_app):
async with role_app.test_client() as c:
yield c
async def test_redirects_unauthenticated(self, role_client, db):
response = await role_client.get("/admin-only", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_rejects_user_without_role(self, role_client, db, test_user):
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/admin-only", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_allows_user_with_matching_role(self, role_client, db, test_user):
await grant_role(test_user["id"], "admin")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/admin-only")
assert response.status_code == 200
async def test_multi_role_allows_any_match(self, role_client, db, test_user):
await grant_role(test_user["id"], "editor")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/multi-role")
assert response.status_code == 200
async def test_multi_role_rejects_none(self, role_client, db, test_user):
await grant_role(test_user["id"], "viewer")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/multi-role", follow_redirects=False)
assert response.status_code in (302, 303, 307)
# ════════════════════════════════════════════════════════════
# Admin route protection
# ════════════════════════════════════════════════════════════
class TestAdminRouteProtection:
async def test_admin_index_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_admin_index_accessible_with_admin_role(self, admin_client, db):
response = await admin_client.get("/admin/")
assert response.status_code == 200
async def test_admin_users_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/users", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_admin_tasks_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/tasks", follow_redirects=False)
assert response.status_code in (302, 303, 307)
# ════════════════════════════════════════════════════════════
# Impersonation
# ════════════════════════════════════════════════════════════
class TestImpersonation:
async def test_impersonate_stores_admin_id(self, admin_client, db, test_user):
"""Impersonating stores admin's user_id in session['admin_impersonating']."""
# Create a second user to impersonate
now = "2025-01-01T00:00:00"
other_id = await core.execute(
"INSERT INTO users (email, name, created_at) VALUES (?, ?, ?)",
("other@example.com", "Other", now),
)
async with admin_client.session_transaction() as sess:
sess["csrf_token"] = "test_csrf"
response = await admin_client.post(
f"/admin/users/{other_id}/impersonate",
form={"csrf_token": "test_csrf"},
follow_redirects=False,
)
assert response.status_code in (302, 303, 307)
async with admin_client.session_transaction() as sess:
assert sess["user_id"] == other_id
assert sess["admin_impersonating"] == test_user["id"]
async def test_stop_impersonating_restores_admin(self, app, db, test_user, grant_role):
"""Stopping impersonation restores the admin's user_id."""
await grant_role(test_user["id"], "admin")
async with app.test_client() as c:
async with c.session_transaction() as sess:
sess["user_id"] = 999 # impersonated user
sess["admin_impersonating"] = test_user["id"]
sess["csrf_token"] = "test_csrf"
response = await c.post(
"/admin/stop-impersonating",
form={"csrf_token": "test_csrf"},
follow_redirects=False,
)
assert response.status_code in (302, 303, 307)
async with c.session_transaction() as sess:
assert sess["user_id"] == test_user["id"]
assert "admin_impersonating" not in sess