Merge branch 'copier-update'

This commit is contained in:
Deeman
2026-02-19 22:46:41 +01:00
22 changed files with 1360 additions and 401 deletions

View File

@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
_commit: v0.3.0
_commit: v0.4.0
_src_path: git@gitlab.com:deemanone/materia_saas_boilerplate.master.git
author_email: hendrik@beanflows.coffee
author_name: Hendrik Deeman

73
web/CHANGELOG.md Normal file
View File

@@ -0,0 +1,73 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
## [Unreleased]
### Changed
- **Role-based access control**: `user_roles` table with `role_required()` decorator replaces password-based admin auth
- **Admin is a real user**: admins authenticate via magic links; `ADMIN_EMAILS` env var auto-grants admin role on login
- **Separated billing entities**: `billing_customers` table holds payment provider identity; `subscriptions` table holds only subscription state
- **Multiple subscriptions per user**: dropped UNIQUE constraint on `subscriptions.user_id`; `upsert_subscription` finds by `provider_subscription_id`
### Added
- Simple A/B testing with `@ab_test` decorator and optional Umami `data-tag` integration (`UMAMI_SCRIPT_URL` / `UMAMI_WEBSITE_ID` env vars)
- `user_roles` table and `grant_role()` / `revoke_role()` / `ensure_admin_role()` functions
- `billing_customers` table and `upsert_billing_customer()` / `get_billing_customer()` functions
- `role_required(*roles)` decorator in auth
- `is_admin` template context variable
- Migration `0001_roles_and_billing_customers.py` for existing databases
### Removed
- `ADMIN_PASSWORD` env var and password-based admin login
- `provider_customer_id` column from `subscriptions` table
- `admin/templates/admin/login.html`
### Changed
- **Provider-agnostic schema**: generic `provider_customer_id` / `provider_subscription_id` columns replace provider-prefixed names (`stripe_customer_id`, `paddle_customer_id`, `lemonsqueezy_customer_id`) — eliminates all Jinja conditionals from schema, SQL helpers, and route code
- **Consolidated `subscription_required` decorator**: single implementation in `auth/routes.py` supporting both plan and status checks, reads from eager-loaded `g.subscription` (zero extra queries)
- **Eager-loaded `g.subscription`**: `load_user` in `app.py` now fetches user + subscription in a single JOIN; available in all routes and templates via `g.subscription`
### Added
- `transactions` table for recording payment/refund events with idempotent `record_transaction()` helper
- Billing event hook system: `on_billing_event()` decorator and `_fire_hooks()` for domain code to react to subscription changes; errors are logged and never cause webhook 500s
### Removed
- Duplicate `subscription_required` decorator from `billing/routes.py` (consolidated in `auth/routes.py`)
- `get_user_with_subscription()` from `auth/routes.py` (replaced by eager-loaded `g.subscription`)
### Changed
- **Email SDK migration**: replaced raw httpx calls with official `resend` SDK in `core.py`
- Added `from_addr` parameter to `send_email()` for multi-address support
- Added `EMAIL_ADDRESSES` dict for named sender addresses (transactional, etc.)
- **Paddle SDK migration**: replaced raw httpx calls with official `paddle-python-sdk` in `billing/routes.py`
- Checkout, manage, cancel routes now use typed SDK methods (`PaddleClient`, `CreateTransaction`)
- Webhook verification uses SDK's `Verifier` instead of hand-rolled HMAC
- Added `PADDLE_ENVIRONMENT` config for sandbox/production toggling
- Added `_paddle_client()` helper factory
- **Dependencies**: `resend` replaces `httpx` for email; `paddle-python-sdk` replaces `httpx` for Paddle billing; `httpx` now only included for LemonSqueezy projects
- Worker `send_email` task handler now passes through `from_addr`
### Added
- `scripts/setup_paddle.py` — CLI script to create Paddle products/prices programmatically (Paddle projects only)
### Changed
- **Pico CSS → Tailwind CSS v4** — full design system migration across all templates
- Standalone Tailwind CLI binary (no Node.js) with `make css-build` / `make css-watch`
- Brand theme with component classes (`.btn`, `.card`, `.form-input`, `.table`, `.badge`, `.flash`, etc.)
- Self-hosted Commit Mono font for monospace data display
- Docker multi-stage build: CSS compiled in dedicated stage before Python build
### Removed
- Pico CSS CDN dependency
- `custom.css` (replaced by Tailwind `input.css` with `@layer components`)
- JetBrains Mono font (replaced by self-hosted Commit Mono)
### Fixed
- Admin template collision: namespaced admin templates under `admin/` subdirectory to prevent Quart's template loader from resolving auth's `login.html` or dashboard's `index.html` instead of admin's
- Admin user detail: `stripe_customer_id` hardcoded regardless of payment provider — now uses provider-aware Copier conditional (Stripe/Paddle/LemonSqueezy)
### Added
- Initial project scaffolded from quart_saas_boilerplate

View File

@@ -1,3 +1,13 @@
# CSS build stage (Tailwind standalone CLI, no Node.js)
FROM debian:bookworm-slim AS css-build
ADD https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 /usr/local/bin/tailwindcss
RUN chmod +x /usr/local/bin/tailwindcss
WORKDIR /app
COPY src/ ./src/
RUN tailwindcss -i ./src/beanflows/static/css/input.css \
-o ./src/beanflows/static/css/output.css --minify
# Build stage
FROM python:3.12-slim AS build
COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/
@@ -15,6 +25,7 @@ RUN useradd -m -u 1000 appuser
WORKDIR /app
RUN mkdir -p /app/data && chown -R appuser:appuser /app
COPY --from=build --chown=appuser:appuser /app .
COPY --from=css-build /app/src/beanflows/static/css/output.css ./src/beanflows/static/css/output.css
USER appuser
ENV PYTHONUNBUFFERED=1
ENV DATABASE_PATH=/app/data/app.db

View File

@@ -21,16 +21,16 @@ services:
command: replicate -config /etc/litestream.yml
volumes:
- app-data:/app/data
- ./beanflows/litestream.yml:/etc/litestream.yml:ro
- ./litestream.yml:/etc/litestream.yml:ro
# ── Blue slot ─────────────────────────────────────────────
blue-app:
profiles: ["blue"]
build:
context: ./beanflows
context: .
restart: unless-stopped
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:
@@ -47,10 +47,10 @@ services:
blue-worker:
profiles: ["blue"]
build:
context: ./beanflows
context: .
restart: unless-stopped
command: python -m beanflows.worker
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:
@@ -61,10 +61,10 @@ services:
blue-scheduler:
profiles: ["blue"]
build:
context: ./beanflows
context: .
restart: unless-stopped
command: python -m beanflows.worker scheduler
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:
@@ -77,9 +77,9 @@ services:
green-app:
profiles: ["green"]
build:
context: ./beanflows
context: .
restart: unless-stopped
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:
@@ -96,10 +96,10 @@ services:
green-worker:
profiles: ["green"]
build:
context: ./beanflows
context: .
restart: unless-stopped
command: python -m beanflows.worker
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:
@@ -110,10 +110,10 @@ services:
green-scheduler:
profiles: ["green"]
build:
context: ./beanflows
context: .
restart: unless-stopped
command: python -m beanflows.worker scheduler
env_file: ./beanflows/.env
env_file: ./.env
environment:
- DATABASE_PATH=/app/data/app.db
volumes:

View File

@@ -12,8 +12,9 @@ dependencies = [
"aiosqlite>=0.19.0",
"duckdb>=1.0.0",
"httpx>=0.27.0",
"resend>=2.22.0",
"python-dotenv>=1.0.0",
"paddle-python-sdk>=1.13.0",
"itsdangerous>=2.1.0",
"jinja2>=3.1.0",
"hypercorn>=0.17.0",

View File

@@ -52,14 +52,40 @@ def create_app() -> Quart:
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
return response
# Load current user before each request
# Load current user + subscription + roles before each request
@app.before_request
async def load_user():
g.user = None
g.subscription = None
user_id = session.get("user_id")
if user_id:
from .auth.routes import get_user_by_id
g.user = await get_user_by_id(user_id)
from .core import fetch_one as _fetch_one
row = await _fetch_one(
"""SELECT u.*,
bc.provider_customer_id,
(SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv,
s.id AS sub_id, s.plan, s.status AS sub_status,
s.provider_subscription_id, s.current_period_end
FROM users u
LEFT JOIN billing_customers bc ON bc.user_id = u.id
LEFT JOIN subscriptions s ON s.id = (
SELECT id FROM subscriptions
WHERE user_id = u.id
ORDER BY created_at DESC LIMIT 1
)
WHERE u.id = ? AND u.deleted_at IS NULL""",
(user_id,),
)
if row:
g.user = dict(row)
g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else []
if row["sub_id"]:
g.subscription = {
"id": row["sub_id"], "plan": row["plan"],
"status": row["sub_status"],
"provider_subscription_id": row["provider_subscription_id"],
"current_period_end": row["current_period_end"],
}
# Template context globals
@app.context_processor
@@ -68,8 +94,12 @@ def create_app() -> Quart:
return {
"config": config,
"user": g.get("user"),
"subscription": g.get("subscription"),
"is_admin": "admin" in (g.get("user") or {}).get("roles", []),
"now": datetime.utcnow(),
"csrf_token": get_csrf_token,
"ab_variant": getattr(g, "ab_variant", None),
"ab_tag": getattr(g, "ab_tag", None),
}
# Health check

View File

@@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None:
)
async def get_user_with_subscription(user_id: int) -> dict | None:
"""Get user with their active subscription info."""
return await fetch_one(
"""
SELECT u.*, s.plan, s.status as sub_status, s.current_period_end
FROM users u
LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active'
WHERE u.id = ? AND u.deleted_at IS NULL
""",
(user_id,)
)
# =============================================================================
# Decorators
# =============================================================================
@@ -116,8 +103,53 @@ def login_required(f):
return decorated
def subscription_required(plans: list[str] = None):
"""Require active subscription, optionally of specific plan(s)."""
def role_required(*roles):
"""Require user to have at least one of the given roles."""
def decorator(f):
@wraps(f)
async def decorated(*args, **kwargs):
if not g.get("user"):
await flash("Please sign in to continue.", "warning")
return redirect(url_for("auth.login", next=request.path))
user_roles = g.user.get("roles", [])
if not any(r in user_roles for r in roles):
await flash("You don't have permission to access that page.", "error")
return redirect(url_for("dashboard.index"))
return await f(*args, **kwargs)
return decorated
return decorator
async def grant_role(user_id: int, role: str) -> None:
"""Grant a role to a user (idempotent)."""
await execute(
"INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
(user_id, role),
)
async def revoke_role(user_id: int, role: str) -> None:
"""Revoke a role from a user."""
await execute(
"DELETE FROM user_roles WHERE user_id = ? AND role = ?",
(user_id, role),
)
async def ensure_admin_role(user_id: int, email: str) -> None:
"""Grant admin role if email is in ADMIN_EMAILS."""
if email.lower() in config.ADMIN_EMAILS:
await grant_role(user_id, "admin")
def subscription_required(
plans: list[str] = None,
allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"),
):
"""Require active subscription, optionally of specific plan(s) and/or statuses.
Reads from g.subscription (eager-loaded in load_user) — zero extra queries.
"""
def decorator(f):
@wraps(f)
async def decorated(*args, **kwargs):
@@ -125,12 +157,12 @@ def subscription_required(plans: list[str] = None):
await flash("Please sign in to continue.", "warning")
return redirect(url_for("auth.login"))
user = await get_user_with_subscription(g.user["id"])
if not user or not user.get("plan"):
sub = g.get("subscription")
if not sub or sub["status"] not in allowed:
await flash("Please subscribe to access this feature.", "warning")
return redirect(url_for("billing.pricing"))
if plans and user["plan"] not in plans:
if plans and sub["plan"] not in plans:
await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning")
return redirect(url_for("billing.pricing"))
@@ -246,6 +278,9 @@ async def verify():
session.permanent = True
session["user_id"] = token_data["user_id"]
# Auto-grant admin role if email is in ADMIN_EMAILS
await ensure_admin_role(token_data["user_id"], token_data["email"])
await flash("Successfully signed in!", "success")
# Redirect to intended page or dashboard
@@ -286,6 +321,9 @@ async def dev_login():
session.permanent = True
session["user_id"] = user_id
# Auto-grant admin role if email is in ADMIN_EMAILS
await ensure_admin_role(user_id, email)
await flash(f"Dev login as {email}", "success")
return redirect(url_for("dashboard.index"))

View File

@@ -2,16 +2,19 @@
Core infrastructure: database, config, email, and shared utilities.
"""
import os
import random
import secrets
import hashlib
import hmac
import aiosqlite
import resend
import httpx
from pathlib import Path
from functools import wraps
from datetime import datetime, timedelta
from contextvars import ContextVar
from quart import request, session, g
from quart import g, make_response, request, session
from dotenv import load_dotenv
# web/.env is three levels up from web/src/beanflows/core.py
@@ -36,14 +39,22 @@ class Config:
PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "")
PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "")
PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
PADDLE_PRICES: dict = {
"starter": os.getenv("PADDLE_PRICE_STARTER", ""),
"pro": os.getenv("PADDLE_PRICE_PRO", ""),
}
UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "")
UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "")
RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "")
EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com")
ADMIN_EMAILS: list[str] = [
e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
]
RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
@@ -153,25 +164,32 @@ class transaction:
# Email
# =============================================================================
async def send_email(to: str, subject: str, html: str, text: str = None) -> bool:
"""Send email via Resend API."""
EMAIL_ADDRESSES = {
"transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>",
}
async def send_email(
to: str, subject: str, html: str, text: str = None, from_addr: str = None
) -> bool:
"""Send email via Resend SDK."""
if not config.RESEND_API_KEY:
print(f"[EMAIL] Would send to {to}: {subject}")
return True
async with httpx.AsyncClient() as client:
response = await client.post(
"https://api.resend.com/emails",
headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"},
json={
"from": config.EMAIL_FROM,
resend.api_key = config.RESEND_API_KEY
try:
resend.Emails.send({
"from": from_addr or config.EMAIL_FROM,
"to": to,
"subject": subject,
"html": html,
"text": text or html,
},
)
return response.status_code == 200
})
return True
except Exception as e:
print(f"[EMAIL] Error sending to {to}: {e}")
return False
# =============================================================================
# CSRF Protection
@@ -294,13 +312,11 @@ def setup_request_id(app):
# Webhook Signature Verification
# =============================================================================
def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
"""Verify HMAC-SHA256 webhook signature."""
expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
return hmac.compare_digest(signature, expected)
# =============================================================================
# Soft Delete Helpers
# =============================================================================
@@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int:
f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?",
(cutoff,)
)
# =============================================================================
# A/B Testing
# =============================================================================
def ab_test(experiment: str, variants: tuple = ("control", "treatment")):
"""Assign visitor to an A/B test variant via cookie, tag Umami pageviews."""
def decorator(f):
@wraps(f)
async def wrapper(*args, **kwargs):
cookie_key = f"ab_{experiment}"
assigned = request.cookies.get(cookie_key)
if assigned not in variants:
assigned = random.choice(variants)
g.ab_variant = assigned
g.ab_tag = f"{experiment}-{assigned}"
response = await make_response(await f(*args, **kwargs))
response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60)
return response
return wrapper
return decorator

View File

@@ -1,47 +1,91 @@
"""
Simple migration runner. Runs schema.sql against the database.
"""
import sqlite3
from pathlib import Path
import os
import sys
Sequential migration runner.
Replays all migrations in order. All databases — fresh and existing —
go through the same path. No schema.sql fast-path.
- Scans versions/ for NNNN_*.py files and runs unapplied ones in order
- Each migration has an up(conn) function receiving an uncommitted connection
- All pending migrations share a single transaction (batch atomicity)
"""
import importlib
import os
import re
import sqlite3
import sys
from pathlib import Path
# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from dotenv import load_dotenv
load_dotenv()
VERSIONS_DIR = Path(__file__).parent / "versions"
VERSION_RE = re.compile(r"^(\d{4})_.+\.py$")
def migrate():
"""Run migrations."""
# Get database path from env or default
# Derived from the package path: …/src/<slug>/migrations/migrate.py
_PACKAGE = Path(__file__).parent.parent.name # e.g. "myproject"
def _discover_versions():
"""Return sorted list of version file stems."""
if not VERSIONS_DIR.is_dir():
return []
versions = []
for f in sorted(VERSIONS_DIR.iterdir()):
if VERSION_RE.match(f.name):
versions.append(f.stem)
return versions
def migrate(db_path=None):
if db_path is None:
db_path = os.getenv("DATABASE_PATH", "data/app.db")
# Ensure directory exists
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
# Read schema
schema_path = Path(__file__).parent / "schema.sql"
schema = schema_path.read_text()
# Connect and execute
conn = sqlite3.connect(db_path)
# Enable WAL mode
conn.execute("PRAGMA journal_mode=WAL")
conn.execute("PRAGMA foreign_keys=ON")
# Run schema
conn.executescript(schema)
# Ensure tracking table exists before anything else
conn.execute("""
CREATE TABLE IF NOT EXISTS _migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
)
""")
conn.commit()
print(f"✓ Migrations complete: {db_path}")
versions = _discover_versions()
applied = {
row[0]
for row in conn.execute("SELECT name FROM _migrations").fetchall()
}
pending = [v for v in versions if v not in applied]
# Show tables
if pending:
for name in pending:
print(f" Applying {name}...")
mod = importlib.import_module(
f"{_PACKAGE}.migrations.versions.{name}"
)
mod.up(conn)
conn.execute(
"INSERT INTO _migrations (name) VALUES (?)", (name,)
)
conn.commit()
print(f"✓ Applied {len(pending)} migration(s): {db_path}")
else:
print(f"✓ All migrations already applied: {db_path}")
# Show tables (excluding internal sqlite/fts tables)
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
"SELECT name FROM sqlite_master WHERE type='table'"
" AND name NOT LIKE 'sqlite_%'"
" ORDER BY name"
)
tables = [row[0] for row in cursor.fetchall()]
print(f" Tables: {', '.join(tables)}")

View File

@@ -1,6 +1,13 @@
-- BeanFlows Database Schema
-- Run with: python -m beanflows.migrations.migrate
-- Migration tracking
CREATE TABLE IF NOT EXISTS _migrations (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT UNIQUE NOT NULL,
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
);
-- Users
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens (
CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token);
CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id);
-- User Roles
CREATE TABLE IF NOT EXISTS user_roles (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
role TEXT NOT NULL,
granted_at TEXT NOT NULL DEFAULT (datetime('now')),
UNIQUE(user_id, role)
);
CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id);
CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role);
-- Billing Customers (payment provider identity, separate from subscriptions)
CREATE TABLE IF NOT EXISTS billing_customers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
provider_customer_id TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now'))
);
CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id);
CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id);
-- Subscriptions
CREATE TABLE IF NOT EXISTS subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
user_id INTEGER NOT NULL REFERENCES users(id),
plan TEXT NOT NULL DEFAULT 'free',
status TEXT NOT NULL DEFAULT 'free',
paddle_customer_id TEXT,
paddle_subscription_id TEXT,
provider_subscription_id TEXT,
current_period_end TEXT,
created_at TEXT NOT NULL,
updated_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id);
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id);
-- Transactions
CREATE TABLE IF NOT EXISTS transactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
subscription_id INTEGER REFERENCES subscriptions(id),
provider_transaction_id TEXT UNIQUE,
type TEXT NOT NULL DEFAULT 'payment',
amount_cents INTEGER,
currency TEXT DEFAULT 'USD',
status TEXT NOT NULL DEFAULT 'pending',
metadata TEXT,
created_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id);
CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id);
-- API Keys
CREATE TABLE IF NOT EXISTS api_keys (
@@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks (
);
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at);
-- Items (example domain entity - replace with your domain)
CREATE TABLE IF NOT EXISTS items (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL REFERENCES users(id),
name TEXT NOT NULL,
data TEXT,
created_at TEXT NOT NULL,
updated_at TEXT,
deleted_at TEXT
);
CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);
-- Full-text search for items (optional)
CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
name,
data,
content='items',
content_rowid='id'
);
-- FTS triggers
CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;
CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
END;
CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
END;

View File

View File

@@ -0,0 +1,92 @@
"""
Create Paddle products and prices for BeanFlows.
Run once per environment (sandbox, then production).
Prints resulting price IDs as a .env snippet.
Usage:
uv run python -m beanflows.scripts.setup_paddle
"""
import os
import sys
from dotenv import load_dotenv
from paddle_billing import Client as PaddleClient
from paddle_billing import Environment, Options
from paddle_billing.Entities.Shared import CurrencyCode, Money, TaxCategory
from paddle_billing.Resources.Prices.Operations import CreatePrice
from paddle_billing.Resources.Products.Operations import CreateProduct
load_dotenv()
PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
if not PADDLE_API_KEY:
print("ERROR: Set PADDLE_API_KEY in .env first")
sys.exit(1)
PRODUCTS = [
# Subscriptions
{
"name": "Starter",
"env_key": "PADDLE_PRICE_STARTER",
"price": 900,
"currency": CurrencyCode.USD,
"interval": "month",
"type": "subscription",
},
{
"name": "Pro",
"env_key": "PADDLE_PRICE_PRO",
"price": 2900,
"currency": CurrencyCode.USD,
"interval": "month",
"type": "subscription",
},
]
def main():
env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))
print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")
env_lines = []
for spec in PRODUCTS:
# Create product
product = paddle.products.create(CreateProduct(
name=spec["name"],
tax_category=TaxCategory.Standard,
))
print(f" Product: {spec['name']} -> {product.id}")
# Create price
price_kwargs = {
"description": spec["name"],
"product_id": product.id,
"unit_price": Money(str(spec["price"]), spec["currency"]),
}
if spec["type"] == "subscription":
from paddle_billing.Entities.Shared import TimePeriod
price_kwargs["billing_cycle"] = TimePeriod(interval="month", frequency=1)
price = paddle.prices.create(CreatePrice(**price_kwargs))
print(f" Price: {spec['env_key']} = {price.id}")
env_lines.append(f"{spec['env_key']}={price.id}")
print("\n# --- .env snippet ---")
for line in env_lines:
print(line)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,90 @@
This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at:
http://scripts.sil.org/OFL
-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------
PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide
development of collaborative font projects, to support the font creation
efforts of academic and linguistic communities, and to provide a free and
open framework in which fonts may be shared and improved in partnership
with others.
The OFL allows the licensed fonts to be used, studied, modified and
redistributed freely as long as they are not sold by themselves. The
fonts, including any derivative works, can be bundled, embedded,
redistributed and/or sold with any software provided that any reserved
names are not used by derivative works. The fonts and derivatives,
however, cannot be released under any other type of license. The
requirement for fonts to remain under this license does not apply
to any document created using the fonts or their derivatives.
DEFINITIONS
"Font Software" refers to the set of files released by the Copyright
Holder(s) under this license and clearly marked as such. This may
include source files, build scripts and documentation.
"Reserved Font Name" refers to any names specified as such after the
copyright statement(s).
"Original Version" refers to the collection of Font Software components as
distributed by the Copyright Holder(s).
"Modified Version" refers to any derivative made by adding to, deleting,
or substituting -- in part or in whole -- any of the components of the
Original Version, by changing formats or by porting the Font Software to a
new environment.
"Author" refers to any designer, engineer, programmer, technical
writer or other person who contributed to the Font Software.
PERMISSION & CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining
a copy of the Font Software, to use, study, copy, merge, embed, modify,
redistribute, and sell modified and unmodified copies of the Font
Software, subject to the following conditions:
1) Neither the Font Software nor any of its individual components,
in Original or Modified Versions, may be sold by itself.
2) Original or Modified Versions of the Font Software may be bundled,
redistributed and/or sold with any software, provided that each copy
contains the above copyright notice and this license. These can be
included either as stand-alone text files, human-readable headers or
in the appropriate machine-readable metadata fields within text or
binary files as long as those fields can be easily viewed by the user.
3) No Modified Version of the Font Software may use the Reserved Font
Name(s) unless explicit written permission is granted by the corresponding
Copyright Holder. This restriction only applies to the primary font name as
presented to the users.
4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
Software shall not be used to promote, endorse or advertise any
Modified Version, except to acknowledge the contribution(s) of the
Copyright Holder(s) and the Author(s) or with their explicit written
permission.
5) The Font Software, modified or unmodified, in part or in whole,
must be distributed entirely under this license, and must not be
distributed under any other license. The requirement for fonts to
remain under this license does not apply to any document created
using the Font Software.
TERMINATION
This license becomes null and void if any of the above conditions are
not met.
DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.

View File

@@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email
HANDLERS: dict[str, callable] = {}
def _email_wrap(body: str) -> str:
"""Wrap email body in a branded layout with inline CSS."""
return f"""\
<!DOCTYPE html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body style="margin:0;padding:0;background-color:#F8FAFC;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background-color:#F8FAFC;padding:40px 0;">
<tr><td align="center">
<table width="480" cellpadding="0" cellspacing="0" style="background-color:#FFFFFF;border-radius:8px;border:1px solid #E2E8F0;overflow:hidden;">
<!-- Header -->
<tr><td style="background-color:#0F172A;padding:24px 32px;">
<span style="color:#FFFFFF;font-size:18px;font-weight:700;letter-spacing:-0.02em;">{config.APP_NAME}</span>
</td></tr>
<!-- Body -->
<tr><td style="padding:32px;color:#475569;font-size:15px;line-height:1.6;">
{body}
</td></tr>
<!-- Footer -->
<tr><td style="padding:20px 32px;border-top:1px solid #E2E8F0;text-align:center;">
<span style="color:#94A3B8;font-size:12px;">&copy; {config.APP_NAME} &middot; You received this because you have an account.</span>
</td></tr>
</table>
</td></tr>
</table>
</body>
</html>"""
def _email_button(url: str, label: str) -> str:
"""Render a branded CTA button for email."""
return (
f'<table cellpadding="0" cellspacing="0" style="margin:24px 0;">'
f'<tr><td style="background-color:#3B82F6;border-radius:6px;text-align:center;">'
f'<a href="{url}" style="display:inline-block;padding:12px 28px;'
f'color:#FFFFFF;font-size:15px;font-weight:600;text-decoration:none;">'
f'{label}</a></td></tr></table>'
)
def task(name: str):
"""Decorator to register a task handler."""
def decorator(f):
@@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None:
subject=payload["subject"],
html=payload["html"],
text=payload.get("text"),
from_addr=payload.get("from_addr"),
)
@@ -107,34 +148,36 @@ async def handle_send_magic_link(payload: dict) -> None:
"""Send magic link email."""
link = f"{config.BASE_URL}/auth/verify?token={payload['token']}"
html = f"""
<h2>Sign in to {config.APP_NAME}</h2>
<p>Click the link below to sign in:</p>
<p><a href="{link}">{link}</a></p>
<p>This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>
<p>If you didn't request this, you can safely ignore this email.</p>
"""
body = (
f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Sign in to {config.APP_NAME}</h2>'
f"<p>Click the button below to sign in. This link expires in "
f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>"
f"{_email_button(link, 'Sign In')}"
f'<p style="font-size:13px;color:#94A3B8;">If the button doesn\'t work, copy and paste this URL into your browser:</p>'
f'<p style="font-size:13px;color:#94A3B8;word-break:break-all;">{link}</p>'
f'<p style="font-size:13px;color:#94A3B8;">If you didn\'t request this, you can safely ignore this email.</p>'
)
await send_email(
to=payload["email"],
subject=f"Sign in to {config.APP_NAME}",
html=html,
html=_email_wrap(body),
)
@task("send_welcome")
async def handle_send_welcome(payload: dict) -> None:
"""Send welcome email to new user."""
html = f"""
<h2>Welcome to {config.APP_NAME}!</h2>
<p>Thanks for signing up. We're excited to have you.</p>
<p><a href="{config.BASE_URL}/dashboard">Go to your dashboard</a></p>
"""
body = (
f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Welcome to {config.APP_NAME}!</h2>'
f"<p>Thanks for signing up. We're excited to have you.</p>"
f'{_email_button(f"{config.BASE_URL}/dashboard", "Go to Dashboard")}'
)
await send_email(
to=payload["email"],
subject=f"Welcome to {config.APP_NAME}",
html=html,
html=_email_wrap(body),
)

View File

@@ -10,9 +10,11 @@ from unittest.mock import AsyncMock, patch
import aiosqlite
import pytest
from beanflows import analytics, core
from beanflows import core
from beanflows.app import create_app
SCHEMA_PATH = Path(__file__).parent.parent / "src" / "beanflows" / "migrations" / "schema.sql"
@@ -44,9 +46,7 @@ async def db():
async def app(db):
"""Quart app with DB already initialized (init_db/close_db patched to no-op)."""
with patch.object(core, "init_db", new_callable=AsyncMock), \
patch.object(core, "close_db", new_callable=AsyncMock), \
patch.object(analytics, "open_analytics_db"), \
patch.object(analytics, "close_analytics_db"):
patch.object(core, "close_db", new_callable=AsyncMock):
application = create_app()
application.config["TESTING"] = True
yield application
@@ -92,22 +92,17 @@ def create_subscription(db):
user_id: int,
plan: str = "pro",
status: str = "active",
paddle_customer_id: str = "ctm_test123",
paddle_subscription_id: str = "sub_test456",
provider_subscription_id: str = "sub_test456",
current_period_end: str = "2025-03-01T00:00:00Z",
) -> int:
now = datetime.utcnow().isoformat()
async with db.execute(
"""INSERT INTO subscriptions
(user_id, plan, status, paddle_customer_id,
paddle_subscription_id, current_period_end, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
(user_id, plan, status, paddle_customer_id, paddle_subscription_id,
(user_id, plan, status,
provider_subscription_id, current_period_end, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
(user_id, plan, status, provider_subscription_id,
current_period_end, now, now),
) as cursor:
sub_id = cursor.lastrowid
await db.commit()
@@ -115,6 +110,48 @@ def create_subscription(db):
return _create
# ── Billing Customers ───────────────────────────────────────
@pytest.fixture
def create_billing_customer(db):
"""Factory: create a billing_customers row for a user."""
async def _create(user_id: int, provider_customer_id: str = "cust_test123") -> int:
async with db.execute(
"""INSERT INTO billing_customers (user_id, provider_customer_id)
VALUES (?, ?)
ON CONFLICT(user_id) DO UPDATE SET provider_customer_id = excluded.provider_customer_id""",
(user_id, provider_customer_id),
) as cursor:
row_id = cursor.lastrowid
await db.commit()
return row_id
return _create
# ── Roles ───────────────────────────────────────────────────
@pytest.fixture
def grant_role(db):
"""Factory: grant a role to a user."""
async def _grant(user_id: int, role: str) -> None:
await db.execute(
"INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
(user_id, role),
)
await db.commit()
return _grant
@pytest.fixture
async def admin_client(app, test_user, grant_role):
"""Test client with admin role and session['user_id'] pre-set."""
await grant_role(test_user["id"], "admin")
async with app.test_client() as c:
async with c.session_transaction() as sess:
sess["user_id"] = test_user["id"]
yield c
# ── Config ───────────────────────────────────────────────────
@pytest.fixture(autouse=True)
@@ -127,6 +164,7 @@ def patch_config():
"PADDLE_API_KEY": "test_api_key_123",
"PADDLE_WEBHOOK_SECRET": "whsec_test_secret",
"PADDLE_ENVIRONMENT": "sandbox",
"PADDLE_PRICES": {"starter": "pri_starter_123", "pro": "pri_pro_456"},
"BASE_URL": "http://localhost:5000",
@@ -147,6 +185,32 @@ def patch_config():
# ── Webhook helpers ──────────────────────────────────────────
@pytest.fixture(autouse=True)
def mock_paddle_verifier(monkeypatch):
"""Mock Paddle's webhook Verifier to accept test payloads."""
def mock_verify(self, payload, secret, signature):
if not signature or signature == "invalid_signature":
raise ValueError("Invalid signature")
monkeypatch.setattr(
"paddle_billing.Notifications.Verifier.verify",
mock_verify,
)
@pytest.fixture
def mock_paddle_client(monkeypatch):
"""Mock _paddle_client() to return a fake PaddleClient."""
from unittest.mock import MagicMock
mock_client = MagicMock()
monkeypatch.setattr(
"beanflows.billing.routes._paddle_client",
lambda: mock_client,
)
return mock_client
def make_webhook_payload(
event_type: str,
subscription_id: str = "sub_test456",
@@ -172,76 +236,8 @@ def make_webhook_payload(
}
def sign_payload(payload_bytes: bytes, secret: str = "whsec_test_secret") -> str:
"""Compute HMAC-SHA256 signature for a webhook payload."""
return hmac.new(secret.encode(), payload_bytes, hashlib.sha256).hexdigest()
# ── Analytics mock data ──────────────────────────────────────
MOCK_TIME_SERIES = [
{"market_year": 2018, "Production": 165000, "Exports": 115000, "Imports": 105000,
"Ending_Stocks": 33000, "Total_Distribution": 160000},
{"market_year": 2019, "Production": 168000, "Exports": 118000, "Imports": 108000,
"Ending_Stocks": 34000, "Total_Distribution": 163000},
{"market_year": 2020, "Production": 170000, "Exports": 120000, "Imports": 110000,
"Ending_Stocks": 35000, "Total_Distribution": 165000},
{"market_year": 2021, "Production": 175000, "Exports": 125000, "Imports": 115000,
"Ending_Stocks": 36000, "Total_Distribution": 170000},
{"market_year": 2022, "Production": 172000, "Exports": 122000, "Imports": 112000,
"Ending_Stocks": 34000, "Total_Distribution": 168000},
]
MOCK_TOP_COUNTRIES = [
{"country_name": "Brazil", "country_code": "BR", "market_year": 2022, "Production": 65000},
{"country_name": "Vietnam", "country_code": "VN", "market_year": 2022, "Production": 30000},
{"country_name": "Colombia", "country_code": "CO", "market_year": 2022, "Production": 14000},
]
MOCK_STU_TREND = [
{"market_year": 2020, "Stock_to_Use_Ratio_pct": 21.2},
{"market_year": 2021, "Stock_to_Use_Ratio_pct": 21.1},
{"market_year": 2022, "Stock_to_Use_Ratio_pct": 20.2},
]
MOCK_BALANCE = [
{"market_year": 2020, "Production": 170000, "Total_Distribution": 165000, "Supply_Demand_Balance": 5000},
{"market_year": 2021, "Production": 175000, "Total_Distribution": 170000, "Supply_Demand_Balance": 5000},
{"market_year": 2022, "Production": 172000, "Total_Distribution": 168000, "Supply_Demand_Balance": 4000},
]
MOCK_YOY = [
{"country_name": "Brazil", "country_code": "BR", "market_year": 2022,
"Production": 65000, "Production_YoY_pct": -3.5},
{"country_name": "Vietnam", "country_code": "VN", "market_year": 2022,
"Production": 30000, "Production_YoY_pct": 2.1},
]
MOCK_COMMODITIES = [
{"commodity_code": 711100, "commodity_name": "Coffee, Green"},
{"commodity_code": 222000, "commodity_name": "Soybeans"},
]
@pytest.fixture
def mock_analytics():
"""Patch all analytics query functions with mock data."""
with patch.object(analytics, "get_global_time_series", new_callable=AsyncMock,
return_value=MOCK_TIME_SERIES), \
patch.object(analytics, "get_top_countries", new_callable=AsyncMock,
return_value=MOCK_TOP_COUNTRIES), \
patch.object(analytics, "get_stock_to_use_trend", new_callable=AsyncMock,
return_value=MOCK_STU_TREND), \
patch.object(analytics, "get_supply_demand_balance", new_callable=AsyncMock,
return_value=MOCK_BALANCE), \
patch.object(analytics, "get_production_yoy_by_country", new_callable=AsyncMock,
return_value=MOCK_YOY), \
patch.object(analytics, "get_country_comparison", new_callable=AsyncMock,
return_value=[]), \
patch.object(analytics, "get_available_commodities", new_callable=AsyncMock,
return_value=MOCK_COMMODITIES), \
patch.object(analytics, "fetch_analytics", new_callable=AsyncMock,
return_value=[{"result": 1}]):
yield
def sign_payload(payload_bytes: bytes) -> str:
"""Return a dummy signature for Paddle webhook tests (Verifier is mocked)."""
return "ts=1234567890;h1=dummy_signature"

View File

@@ -9,10 +9,13 @@ from hypothesis import strategies as st
from beanflows.billing.routes import (
can_access_feature,
get_billing_customer,
get_subscription,
get_subscription_by_provider_id,
is_within_limits,
record_transaction,
update_subscription_status,
upsert_billing_customer,
upsert_subscription,
)
from beanflows.core import config
@@ -45,7 +48,6 @@ class TestUpsertSubscription:
user_id=test_user["id"],
plan="pro",
status="active",
provider_customer_id="cust_abc",
provider_subscription_id="sub_xyz",
current_period_end="2025-06-01T00:00:00Z",
)
@@ -53,39 +55,53 @@ class TestUpsertSubscription:
row = await get_subscription(test_user["id"])
assert row["plan"] == "pro"
assert row["status"] == "active"
assert row["paddle_customer_id"] == "cust_abc"
assert row["paddle_subscription_id"] == "sub_xyz"
assert row["provider_subscription_id"] == "sub_xyz"
assert row["current_period_end"] == "2025-06-01T00:00:00Z"
async def test_update_existing_subscription(self, db, test_user, create_subscription):
original_id = await create_subscription(
test_user["id"], plan="starter", status="active",
paddle_subscription_id="sub_old",
async def test_update_existing_by_provider_subscription_id(self, db, test_user):
"""upsert finds existing by provider_subscription_id, not user_id."""
await upsert_subscription(
user_id=test_user["id"],
plan="starter",
status="active",
provider_subscription_id="sub_same",
)
returned_id = await upsert_subscription(
user_id=test_user["id"],
plan="pro",
status="active",
provider_customer_id="cust_new",
provider_subscription_id="sub_new",
provider_subscription_id="sub_same",
)
assert returned_id == original_id
row = await get_subscription(test_user["id"])
assert row["plan"] == "pro"
assert row["provider_subscription_id"] == "sub_same"
assert row["paddle_subscription_id"] == "sub_new"
async def test_different_provider_id_creates_new(self, db, test_user):
"""Different provider_subscription_id creates a new row (multi-sub support)."""
await upsert_subscription(
user_id=test_user["id"],
plan="starter",
status="active",
provider_subscription_id="sub_first",
)
await upsert_subscription(
user_id=test_user["id"],
plan="pro",
status="active",
provider_subscription_id="sub_second",
)
from beanflows.core import fetch_all
rows = await fetch_all(
"SELECT * FROM subscriptions WHERE user_id = ? ORDER BY created_at",
(test_user["id"],),
)
assert len(rows) == 2
async def test_upsert_with_none_period_end(self, db, test_user):
await upsert_subscription(
user_id=test_user["id"],
plan="pro",
status="active",
provider_customer_id="cust_1",
provider_subscription_id="sub_1",
current_period_end=None,
)
@@ -93,6 +109,28 @@ class TestUpsertSubscription:
assert row["current_period_end"] is None
# ════════════════════════════════════════════════════════════
# upsert_billing_customer / get_billing_customer
# ════════════════════════════════════════════════════════════
class TestUpsertBillingCustomer:
async def test_creates_billing_customer(self, db, test_user):
await upsert_billing_customer(test_user["id"], "cust_abc")
row = await get_billing_customer(test_user["id"])
assert row is not None
assert row["provider_customer_id"] == "cust_abc"
async def test_updates_existing_customer(self, db, test_user):
await upsert_billing_customer(test_user["id"], "cust_old")
await upsert_billing_customer(test_user["id"], "cust_new")
row = await get_billing_customer(test_user["id"])
assert row["provider_customer_id"] == "cust_new"
async def test_get_returns_none_for_unknown_user(self, db):
row = await get_billing_customer(99999)
assert row is None
# ════════════════════════════════════════════════════════════
# get_subscription_by_provider_id
# ════════════════════════════════════════════════════════════
@@ -102,10 +140,8 @@ class TestGetSubscriptionByProviderId:
result = await get_subscription_by_provider_id("nonexistent")
assert result is None
async def test_finds_by_paddle_subscription_id(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_findme")
async def test_finds_by_provider_subscription_id(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], provider_subscription_id="sub_findme")
result = await get_subscription_by_provider_id("sub_findme")
assert result is not None
assert result["user_id"] == test_user["id"]
@@ -117,18 +153,14 @@ class TestGetSubscriptionByProviderId:
class TestUpdateSubscriptionStatus:
async def test_updates_status(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_upd")
await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_upd")
await update_subscription_status("sub_upd", status="cancelled")
row = await get_subscription(test_user["id"])
assert row["status"] == "cancelled"
assert row["updated_at"] is not None
async def test_updates_extra_fields(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_extra")
await create_subscription(test_user["id"], provider_subscription_id="sub_extra")
await update_subscription_status(
"sub_extra",
status="active",
@@ -141,9 +173,7 @@ class TestUpdateSubscriptionStatus:
assert row["current_period_end"] == "2026-01-01T00:00:00Z"
async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_known", status="active")
await create_subscription(test_user["id"], provider_subscription_id="sub_known", status="active")
await update_subscription_status("sub_unknown", status="expired")
row = await get_subscription(test_user["id"])
assert row["status"] == "active" # unchanged
@@ -155,22 +185,22 @@ class TestUpdateSubscriptionStatus:
class TestCanAccessFeature:
async def test_no_subscription_gets_free_features(self, db, test_user):
assert await can_access_feature(test_user["id"], "dashboard") is True
assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is False
assert await can_access_feature(test_user["id"], "api") is False
async def test_active_pro_gets_all_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="active")
assert await can_access_feature(test_user["id"], "dashboard") is True
assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is True
assert await can_access_feature(test_user["id"], "api") is True
assert await can_access_feature(test_user["id"], "priority_support") is True
async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="starter", status="active")
assert await can_access_feature(test_user["id"], "dashboard") is True
assert await can_access_feature(test_user["id"], "basic") is True
assert await can_access_feature(test_user["id"], "export") is True
assert await can_access_feature(test_user["id"], "all_commodities") is False
assert await can_access_feature(test_user["id"], "api") is False
async def test_cancelled_still_has_features(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="cancelled")
@@ -183,7 +213,7 @@ class TestCanAccessFeature:
async def test_expired_falls_back_to_free(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="expired")
assert await can_access_feature(test_user["id"], "api") is False
assert await can_access_feature(test_user["id"], "dashboard") is True
assert await can_access_feature(test_user["id"], "basic") is True
async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="past_due")
@@ -203,30 +233,28 @@ class TestCanAccessFeature:
# ════════════════════════════════════════════════════════════
class TestIsWithinLimits:
async def test_free_user_no_api_calls(self, db, test_user):
assert await is_within_limits(test_user["id"], "api_calls", 0) is False
async def test_free_user_within_limits(self, db, test_user):
assert await is_within_limits(test_user["id"], "items", 50) is True
async def test_free_user_commodity_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "commodities", 0) is True
assert await is_within_limits(test_user["id"], "commodities", 1) is False
async def test_free_user_at_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "items", 100) is False
async def test_free_user_history_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "history_years", 4) is True
assert await is_within_limits(test_user["id"], "history_years", 5) is False
async def test_free_user_over_limit(self, db, test_user):
assert await is_within_limits(test_user["id"], "items", 150) is False
async def test_pro_unlimited(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="active")
assert await is_within_limits(test_user["id"], "commodities", 999999) is True
assert await is_within_limits(test_user["id"], "items", 999999) is True
assert await is_within_limits(test_user["id"], "api_calls", 999999) is True
async def test_starter_limits(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="starter", status="active")
assert await is_within_limits(test_user["id"], "api_calls", 9999) is True
assert await is_within_limits(test_user["id"], "api_calls", 10000) is False
assert await is_within_limits(test_user["id"], "items", 999) is True
assert await is_within_limits(test_user["id"], "items", 1000) is False
async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription):
await create_subscription(test_user["id"], plan="pro", status="expired")
assert await is_within_limits(test_user["id"], "api_calls", 0) is False
assert await is_within_limits(test_user["id"], "items", 100) is False
async def test_unknown_resource_returns_false(self, db, test_user):
assert await is_within_limits(test_user["id"], "unicorns", 0) is False
@@ -238,7 +266,7 @@ class TestIsWithinLimits:
# ════════════════════════════════════════════════════════════
STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"]
FEATURES = ["dashboard", "export", "api", "priority_support"]
FEATURES = ["basic", "export", "api", "priority_support"]
ACTIVE_STATUSES = {"active", "on_trial", "cancelled"}
@@ -282,9 +310,9 @@ async def test_plan_feature_matrix(db, test_user, create_subscription, plan, fea
@pytest.mark.parametrize("plan", PLANS)
@pytest.mark.parametrize("resource,at_limit", [
("commodities", 1),
("commodities", 65),
("api_calls", 0),
("items", 100),
("items", 1000),
("api_calls", 1000),
("api_calls", 10000),
])
async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit):
@@ -307,11 +335,11 @@ async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resou
# ════════════════════════════════════════════════════════════
class TestLimitsHypothesis:
@given(count=st.integers(min_value=0, max_value=100))
@given(count=st.integers(min_value=0, max_value=10000))
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
async def test_free_limit_boundary_commodities(self, db, test_user, count):
result = await is_within_limits(test_user["id"], "commodities", count)
assert result == (count < 1)
async def test_free_limit_boundary_items(self, db, test_user, count):
result = await is_within_limits(test_user["id"], "items", count)
assert result == (count < 100)
@given(count=st.integers(min_value=0, max_value=100000))
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
@@ -319,7 +347,56 @@ class TestLimitsHypothesis:
# Use upsert to avoid duplicate inserts across Hypothesis examples
await upsert_subscription(
user_id=test_user["id"], plan="pro", status="active",
provider_customer_id="cust_hyp", provider_subscription_id="sub_hyp",
provider_subscription_id="sub_hyp",
)
result = await is_within_limits(test_user["id"], "commodities", count)
result = await is_within_limits(test_user["id"], "items", count)
assert result is True
# ════════════════════════════════════════════════════════════
# record_transaction
# ════════════════════════════════════════════════════════════
class TestRecordTransaction:
async def test_inserts_transaction(self, db, test_user):
txn_id = await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_abc123",
type="payment",
amount_cents=2999,
currency="EUR",
status="completed",
)
assert txn_id is not None and txn_id > 0
from beanflows.core import fetch_one
row = await fetch_one(
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
("txn_abc123",),
)
assert row is not None
assert row["user_id"] == test_user["id"]
assert row["amount_cents"] == 2999
assert row["currency"] == "EUR"
assert row["status"] == "completed"
async def test_idempotent_on_duplicate_provider_id(self, db, test_user):
await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_dup",
amount_cents=1000,
)
# Second insert with same provider_transaction_id should be ignored
await record_transaction(
user_id=test_user["id"],
provider_transaction_id="txn_dup",
amount_cents=9999,
)
from beanflows.core import fetch_all
rows = await fetch_all(
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
("txn_dup",),
)
assert len(rows) == 1
assert rows[0]["amount_cents"] == 1000 # original value preserved

View File

@@ -0,0 +1,122 @@
"""
Tests for the billing event hook system.
"""
import pytest
from beanflows.billing.routes import _billing_hooks, _fire_hooks, on_billing_event
@pytest.fixture(autouse=True)
def clear_hooks():
"""Ensure hooks are clean before and after each test."""
_billing_hooks.clear()
yield
_billing_hooks.clear()
# ════════════════════════════════════════════════════════════
# Registration
# ════════════════════════════════════════════════════════════
class TestOnBillingEvent:
def test_registers_single_event(self):
@on_billing_event("subscription.activated")
async def my_hook(event_type, data):
pass
assert "subscription.activated" in _billing_hooks
assert my_hook in _billing_hooks["subscription.activated"]
def test_registers_multiple_events(self):
@on_billing_event("subscription.activated", "subscription.updated")
async def my_hook(event_type, data):
pass
assert my_hook in _billing_hooks["subscription.activated"]
assert my_hook in _billing_hooks["subscription.updated"]
def test_multiple_hooks_per_event(self):
@on_billing_event("subscription.activated")
async def hook_a(event_type, data):
pass
@on_billing_event("subscription.activated")
async def hook_b(event_type, data):
pass
assert len(_billing_hooks["subscription.activated"]) == 2
def test_decorator_returns_original_function(self):
@on_billing_event("test_event")
async def my_hook(event_type, data):
pass
assert my_hook.__name__ == "my_hook"
# ════════════════════════════════════════════════════════════
# Firing
# ════════════════════════════════════════════════════════════
class TestFireHooks:
async def test_fires_registered_hook(self):
calls = []
@on_billing_event("subscription.activated")
async def recorder(event_type, data):
calls.append((event_type, data))
await _fire_hooks("subscription.activated", {"id": "sub_123"})
assert len(calls) == 1
assert calls[0] == ("subscription.activated", {"id": "sub_123"})
async def test_no_hooks_registered_is_noop(self):
# Should not raise
await _fire_hooks("unregistered_event", {"id": "sub_123"})
async def test_fires_all_hooks_for_event(self):
calls = []
@on_billing_event("subscription.activated")
async def hook_a(event_type, data):
calls.append("a")
@on_billing_event("subscription.activated")
async def hook_b(event_type, data):
calls.append("b")
await _fire_hooks("subscription.activated", {})
assert calls == ["a", "b"]
# ════════════════════════════════════════════════════════════
# Error isolation
# ════════════════════════════════════════════════════════════
class TestHookErrorIsolation:
async def test_failing_hook_does_not_block_others(self):
calls = []
@on_billing_event("subscription.activated")
async def failing_hook(event_type, data):
raise RuntimeError("boom")
@on_billing_event("subscription.activated")
async def good_hook(event_type, data):
calls.append("ok")
# Should not raise despite first hook failing
await _fire_hooks("subscription.activated", {})
assert calls == ["ok"]
async def test_failing_hook_is_logged(self, caplog):
@on_billing_event("subscription.activated")
async def bad_hook(event_type, data):
raise ValueError("test error")
import logging
with caplog.at_level(logging.ERROR):
await _fire_hooks("subscription.activated", {})
assert "bad_hook" in caplog.text
assert "test error" in caplog.text

View File

@@ -1,12 +1,15 @@
"""
Route integration tests for Paddle billing endpoints.
External Paddle API calls mocked with respx.
"""
import json
import httpx
Paddle SDK calls mocked via mock_paddle_client fixture.
"""
from unittest.mock import MagicMock
import pytest
import respx
CHECKOUT_METHOD = "POST"
@@ -54,24 +57,16 @@ class TestCheckoutRoute:
assert response.status_code in (302, 303, 307)
@respx.mock
async def test_creates_checkout_session(self, auth_client, db, test_user):
respx.post("https://api.paddle.com/transactions").mock(
return_value=httpx.Response(200, json={
"data": {
"checkout": {
"url": "https://checkout.paddle.com/test_123"
}
}
})
)
async def test_creates_checkout_session(self, auth_client, db, test_user, mock_paddle_client):
mock_txn = MagicMock()
mock_txn.checkout.url = "https://checkout.paddle.com/test_123"
mock_paddle_client.transactions.create.return_value = mock_txn
response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False)
assert response.status_code in (302, 303, 307)
mock_paddle_client.transactions.create.assert_called_once()
async def test_invalid_plan_rejected(self, auth_client, db, test_user):
@@ -82,20 +77,13 @@ class TestCheckoutRoute:
@respx.mock
async def test_api_error_propagates(self, auth_client, db, test_user):
respx.post("https://api.paddle.com/transactions").mock(
return_value=httpx.Response(500, json={"error": "server error"})
)
with pytest.raises(httpx.HTTPStatusError):
async def test_api_error_propagates(self, auth_client, db, test_user, mock_paddle_client):
mock_paddle_client.transactions.create.side_effect = Exception("API error")
with pytest.raises(Exception, match="API error"):
await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}")
# ════════════════════════════════════════════════════════════
# Manage subscription / Portal
# ════════════════════════════════════════════════════════════
@@ -110,24 +98,18 @@ class TestManageRoute:
response = await auth_client.post("/billing/manage", follow_redirects=False)
assert response.status_code in (302, 303, 307)
@respx.mock
async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
respx.get("https://api.paddle.com/subscriptions/sub_test").mock(
return_value=httpx.Response(200, json={
"data": {
"management_urls": {
"update_payment_method": "https://paddle.com/manage/test_123"
}
}
})
)
async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
await create_subscription(test_user["id"], provider_subscription_id="sub_test")
mock_sub = MagicMock()
mock_sub.management_urls.update_payment_method = "https://paddle.com/manage/test_123"
mock_paddle_client.subscriptions.get.return_value = mock_sub
response = await auth_client.post("/billing/manage", follow_redirects=False)
assert response.status_code in (302, 303, 307)
mock_paddle_client.subscriptions.get.assert_called_once_with("sub_test")
@@ -145,18 +127,14 @@ class TestCancelRoute:
response = await auth_client.post("/billing/cancel", follow_redirects=False)
assert response.status_code in (302, 303, 307)
@respx.mock
async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription):
await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
respx.post("https://api.paddle.com/subscriptions/sub_test/cancel").mock(
return_value=httpx.Response(200, json={"data": {}})
)
async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
await create_subscription(test_user["id"], provider_subscription_id="sub_test")
response = await auth_client.post("/billing/cancel", follow_redirects=False)
assert response.status_code in (302, 303, 307)
mock_paddle_client.subscriptions.cancel.assert_called_once()
@@ -167,8 +145,9 @@ class TestCancelRoute:
# subscription_required decorator
# ════════════════════════════════════════════════════════════
from beanflows.billing.routes import subscription_required
from quart import Blueprint
from quart import Blueprint # noqa: E402
from beanflows.auth.routes import subscription_required # noqa: E402
test_bp = Blueprint("test", __name__)

View File

@@ -5,13 +5,13 @@ Covers signature verification, event parsing, subscription lifecycle transitions
import json
import pytest
from conftest import make_webhook_payload, sign_payload
from hypothesis import HealthCheck, given
from hypothesis import settings as h_settings
from hypothesis import strategies as st
from beanflows.billing.routes import get_subscription
from conftest import make_webhook_payload, sign_payload
from beanflows.billing.routes import get_billing_customer, get_subscription
WEBHOOK_PATH = "/billing/webhook/paddle"
@@ -72,18 +72,19 @@ class TestWebhookSignature:
async def test_modified_payload_rejected(self, client, db, test_user):
# Paddle SDK Verifier handles tamper detection internally.
# We test signature rejection via test_invalid_signature_rejected above.
# This test verifies the Verifier is actually called by sending
# a payload with an explicitly bad signature.
payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"]))
payload_bytes = json.dumps(payload).encode()
sig = sign_payload(payload_bytes)
tampered = payload_bytes + b"extra"
# Paddle/LemonSqueezy: HMAC signature verification fails before JSON parsing
response = await client.post(
WEBHOOK_PATH,
data=tampered,
headers={SIG_HEADER: sig, "Content-Type": "application/json"},
data=payload_bytes,
headers={SIG_HEADER: "invalid_signature", "Content-Type": "application/json"},
)
assert response.status_code in (400, 401)
assert response.status_code == 400
async def test_empty_payload_rejected(self, client, db):
@@ -105,7 +106,7 @@ class TestWebhookSignature:
class TestWebhookSubscriptionActivated:
async def test_creates_subscription(self, client, db, test_user):
async def test_creates_subscription_and_billing_customer(self, client, db, test_user):
payload = make_webhook_payload(
"subscription.activated",
user_id=str(test_user["id"]),
@@ -126,10 +127,14 @@ class TestWebhookSubscriptionActivated:
assert sub["plan"] == "starter"
assert sub["status"] == "active"
bc = await get_billing_customer(test_user["id"])
assert bc is not None
assert bc["provider_customer_id"] == "ctm_test123"
class TestWebhookSubscriptionUpdated:
async def test_updates_subscription_status(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload(
"subscription.updated",
@@ -152,7 +157,7 @@ class TestWebhookSubscriptionUpdated:
class TestWebhookSubscriptionCanceled:
async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload(
"subscription.canceled",
@@ -174,7 +179,7 @@ class TestWebhookSubscriptionCanceled:
class TestWebhookSubscriptionPastDue:
async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription):
await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
payload = make_webhook_payload(
"subscription.past_due",
@@ -209,7 +214,7 @@ class TestWebhookSubscriptionPastDue:
])
async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status):
if event_type != "subscription.activated":
await create_subscription(test_user["id"], paddle_subscription_id="sub_test456")
await create_subscription(test_user["id"], provider_subscription_id="sub_test456")
payload = make_webhook_payload(event_type, user_id=str(test_user["id"]))
payload_bytes = json.dumps(payload).encode()

242
web/tests/test_roles.py Normal file
View File

@@ -0,0 +1,242 @@
"""
Tests for role-based access control: role_required decorator, grant/revoke/ensure_admin_role,
and admin route protection.
"""
import pytest
from quart import Blueprint
from beanflows.auth.routes import (
ensure_admin_role,
grant_role,
revoke_role,
role_required,
)
from beanflows import core
# ════════════════════════════════════════════════════════════
# grant_role / revoke_role
# ════════════════════════════════════════════════════════════
class TestGrantRole:
async def test_grants_role(self, db, test_user):
await grant_role(test_user["id"], "admin")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ?",
(test_user["id"],),
)
assert row is not None
assert row["role"] == "admin"
async def test_idempotent(self, db, test_user):
await grant_role(test_user["id"], "admin")
await grant_role(test_user["id"], "admin")
rows = await core.fetch_all(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert len(rows) == 1
class TestRevokeRole:
async def test_revokes_existing_role(self, db, test_user):
await grant_role(test_user["id"], "admin")
await revoke_role(test_user["id"], "admin")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
async def test_noop_for_missing_role(self, db, test_user):
# Should not raise
await revoke_role(test_user["id"], "nonexistent")
# ════════════════════════════════════════════════════════════
# ensure_admin_role
# ════════════════════════════════════════════════════════════
class TestEnsureAdminRole:
async def test_grants_admin_for_listed_email(self, db, test_user):
core.config.ADMIN_EMAILS = ["test@example.com"]
try:
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is not None
finally:
core.config.ADMIN_EMAILS = []
async def test_skips_for_unlisted_email(self, db, test_user):
core.config.ADMIN_EMAILS = ["boss@example.com"]
try:
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
finally:
core.config.ADMIN_EMAILS = []
async def test_empty_admin_emails_grants_nothing(self, db, test_user):
core.config.ADMIN_EMAILS = []
await ensure_admin_role(test_user["id"], "test@example.com")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is None
async def test_case_insensitive_matching(self, db, test_user):
core.config.ADMIN_EMAILS = ["test@example.com"]
try:
await ensure_admin_role(test_user["id"], "Test@Example.COM")
row = await core.fetch_one(
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
(test_user["id"],),
)
assert row is not None
finally:
core.config.ADMIN_EMAILS = []
# ════════════════════════════════════════════════════════════
# role_required decorator
# ════════════════════════════════════════════════════════════
role_test_bp = Blueprint("role_test", __name__)
@role_test_bp.route("/admin-only")
@role_required("admin")
async def admin_only_route():
return "admin-ok", 200
@role_test_bp.route("/multi-role")
@role_required("admin", "editor")
async def multi_role_route():
return "multi-ok", 200
class TestRoleRequired:
@pytest.fixture
async def role_app(self, app):
app.register_blueprint(role_test_bp)
return app
@pytest.fixture
async def role_client(self, role_app):
async with role_app.test_client() as c:
yield c
async def test_redirects_unauthenticated(self, role_client, db):
response = await role_client.get("/admin-only", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_rejects_user_without_role(self, role_client, db, test_user):
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/admin-only", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_allows_user_with_matching_role(self, role_client, db, test_user):
await grant_role(test_user["id"], "admin")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/admin-only")
assert response.status_code == 200
async def test_multi_role_allows_any_match(self, role_client, db, test_user):
await grant_role(test_user["id"], "editor")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/multi-role")
assert response.status_code == 200
async def test_multi_role_rejects_none(self, role_client, db, test_user):
await grant_role(test_user["id"], "viewer")
async with role_client.session_transaction() as sess:
sess["user_id"] = test_user["id"]
response = await role_client.get("/multi-role", follow_redirects=False)
assert response.status_code in (302, 303, 307)
# ════════════════════════════════════════════════════════════
# Admin route protection
# ════════════════════════════════════════════════════════════
class TestAdminRouteProtection:
async def test_admin_index_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_admin_index_accessible_with_admin_role(self, admin_client, db):
response = await admin_client.get("/admin/")
assert response.status_code == 200
async def test_admin_users_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/users", follow_redirects=False)
assert response.status_code in (302, 303, 307)
async def test_admin_tasks_requires_admin_role(self, auth_client, db):
response = await auth_client.get("/admin/tasks", follow_redirects=False)
assert response.status_code in (302, 303, 307)
# ════════════════════════════════════════════════════════════
# Impersonation
# ════════════════════════════════════════════════════════════
class TestImpersonation:
async def test_impersonate_stores_admin_id(self, admin_client, db, test_user):
"""Impersonating stores admin's user_id in session['admin_impersonating']."""
# Create a second user to impersonate
now = "2025-01-01T00:00:00"
other_id = await core.execute(
"INSERT INTO users (email, name, created_at) VALUES (?, ?, ?)",
("other@example.com", "Other", now),
)
async with admin_client.session_transaction() as sess:
sess["csrf_token"] = "test_csrf"
response = await admin_client.post(
f"/admin/users/{other_id}/impersonate",
form={"csrf_token": "test_csrf"},
follow_redirects=False,
)
assert response.status_code in (302, 303, 307)
async with admin_client.session_transaction() as sess:
assert sess["user_id"] == other_id
assert sess["admin_impersonating"] == test_user["id"]
async def test_stop_impersonating_restores_admin(self, app, db, test_user, grant_role):
"""Stopping impersonation restores the admin's user_id."""
await grant_role(test_user["id"], "admin")
async with app.test_client() as c:
async with c.session_transaction() as sess:
sess["user_id"] = 999 # impersonated user
sess["admin_impersonating"] = test_user["id"]
sess["csrf_token"] = "test_csrf"
response = await c.post(
"/admin/stop-impersonating",
form={"csrf_token": "test_csrf"},
follow_redirects=False,
)
assert response.status_code in (302, 303, 307)
async with c.session_transaction() as sess:
assert sess["user_id"] == test_user["id"]
assert "admin_impersonating" not in sess