Merge branch 'copier-update'

@@ -1,5 +1,5 @@
 # Changes here will be overwritten by Copier; NEVER EDIT MANUALLY
-_commit: v0.3.0
+_commit: v0.4.0
 _src_path: git@gitlab.com:deemanone/materia_saas_boilerplate.master.git
 author_email: hendrik@beanflows.coffee
 author_name: Hendrik Deeman

73	web/CHANGELOG.md	Normal file
@@ -0,0 +1,73 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## [Unreleased]

### Changed

- **Role-based access control**: `user_roles` table with `role_required()` decorator replaces password-based admin auth
- **Admin is a real user**: admins authenticate via magic links; `ADMIN_EMAILS` env var auto-grants the admin role on login
- **Separated billing entities**: `billing_customers` table holds payment provider identity; `subscriptions` table holds only subscription state
- **Multiple subscriptions per user**: dropped the UNIQUE constraint on `subscriptions.user_id`; `upsert_subscription` finds by `provider_subscription_id`

### Added

- Simple A/B testing with `@ab_test` decorator and optional Umami `data-tag` integration (`UMAMI_SCRIPT_URL` / `UMAMI_WEBSITE_ID` env vars)
- `user_roles` table and `grant_role()` / `revoke_role()` / `ensure_admin_role()` functions
- `billing_customers` table and `upsert_billing_customer()` / `get_billing_customer()` functions
- `role_required(*roles)` decorator in auth
- `is_admin` template context variable
- Migration `0001_roles_and_billing_customers.py` for existing databases
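The role gate described above ultimately reduces to a membership check. A framework-free sketch of that core logic (the function name `has_any_role` is illustrative, not from the codebase):

```python
def has_any_role(user_roles: list[str], required: tuple[str, ...]) -> bool:
    """Core check behind a role_required-style decorator:
    pass if the user holds at least one of the required roles."""
    return any(r in user_roles for r in required)

# A user with only "support" passes an ("admin", "support") gate,
# but fails an ("admin",)-only gate.
print(has_any_role(["support"], ("admin", "support")))  # → True
print(has_any_role(["support"], ("admin",)))            # → False
```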
### Removed

- `ADMIN_PASSWORD` env var and password-based admin login
- `provider_customer_id` column from `subscriptions` table
- `admin/templates/admin/login.html`

### Changed

- **Provider-agnostic schema**: generic `provider_customer_id` / `provider_subscription_id` columns replace provider-prefixed names (`stripe_customer_id`, `paddle_customer_id`, `lemonsqueezy_customer_id`) — eliminates all Jinja conditionals from schema, SQL helpers, and route code
- **Consolidated `subscription_required` decorator**: single implementation in `auth/routes.py` supporting both plan and status checks, reads from eager-loaded `g.subscription` (zero extra queries)
- **Eager-loaded `g.subscription`**: `load_user` in `app.py` now fetches user + subscription in a single JOIN; available in all routes and templates via `g.subscription`

### Added

- `transactions` table for recording payment/refund events with idempotent `record_transaction()` helper
- Billing event hook system: `on_billing_event()` decorator and `_fire_hooks()` for domain code to react to subscription changes; errors are logged and never cause webhook 500s
|
||||
|
||||
### Removed
|
||||
- Duplicate `subscription_required` decorator from `billing/routes.py` (consolidated in `auth/routes.py`)
|
||||
- `get_user_with_subscription()` from `auth/routes.py` (replaced by eager-loaded `g.subscription`)
|
||||
|
||||
### Changed
|
||||
- **Email SDK migration**: replaced raw httpx calls with official `resend` SDK in `core.py`
|
||||
- Added `from_addr` parameter to `send_email()` for multi-address support
|
||||
- Added `EMAIL_ADDRESSES` dict for named sender addresses (transactional, etc.)
|
||||
- **Paddle SDK migration**: replaced raw httpx calls with official `paddle-python-sdk` in `billing/routes.py`
|
||||
- Checkout, manage, cancel routes now use typed SDK methods (`PaddleClient`, `CreateTransaction`)
|
||||
- Webhook verification uses SDK's `Verifier` instead of hand-rolled HMAC
|
||||
- Added `PADDLE_ENVIRONMENT` config for sandbox/production toggling
|
||||
- Added `_paddle_client()` helper factory
|
||||
- **Dependencies**: `resend` replaces `httpx` for email; `paddle-python-sdk` replaces `httpx` for Paddle billing; `httpx` now only included for LemonSqueezy projects
|
||||
- Worker `send_email` task handler now passes through `from_addr`
|
||||
|
||||
### Added
|
||||
- `scripts/setup_paddle.py` — CLI script to create Paddle products/prices programmatically (Paddle projects only)
|
||||
|
||||
### Changed
|
||||
- **Pico CSS → Tailwind CSS v4** — full design system migration across all templates
|
||||
- Standalone Tailwind CLI binary (no Node.js) with `make css-build` / `make css-watch`
|
||||
- Brand theme with component classes (`.btn`, `.card`, `.form-input`, `.table`, `.badge`, `.flash`, etc.)
|
||||
- Self-hosted Commit Mono font for monospace data display
|
||||
- Docker multi-stage build: CSS compiled in dedicated stage before Python build
|
||||
|
||||
### Removed
|
||||
- Pico CSS CDN dependency
|
||||
- `custom.css` (replaced by Tailwind `input.css` with `@layer components`)
|
||||
- JetBrains Mono font (replaced by self-hosted Commit Mono)
|
||||
|
||||
### Fixed
|
||||
- Admin template collision: namespaced admin templates under `admin/` subdirectory to prevent Quart's template loader from resolving auth's `login.html` or dashboard's `index.html` instead of admin's
|
||||
- Admin user detail: `stripe_customer_id` hardcoded regardless of payment provider — now uses provider-aware Copier conditional (Stripe/Paddle/LemonSqueezy)
|
||||
|
||||
### Added
|
||||
- Initial project scaffolded from quart_saas_boilerplate
|
||||
@@ -1,3 +1,13 @@
+# CSS build stage (Tailwind standalone CLI, no Node.js)
+FROM debian:bookworm-slim AS css-build
+ADD https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 /usr/local/bin/tailwindcss
+RUN chmod +x /usr/local/bin/tailwindcss
+WORKDIR /app
+COPY src/ ./src/
+RUN tailwindcss -i ./src/beanflows/static/css/input.css \
+    -o ./src/beanflows/static/css/output.css --minify
+
 # Build stage
 FROM python:3.12-slim AS build
 COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/

@@ -15,6 +25,7 @@ RUN useradd -m -u 1000 appuser
 WORKDIR /app
 RUN mkdir -p /app/data && chown -R appuser:appuser /app
 COPY --from=build --chown=appuser:appuser /app .
+COPY --from=css-build /app/src/beanflows/static/css/output.css ./src/beanflows/static/css/output.css
 USER appuser
 ENV PYTHONUNBUFFERED=1
 ENV DATABASE_PATH=/app/data/app.db
@@ -21,16 +21,16 @@ services:
     command: replicate -config /etc/litestream.yml
     volumes:
       - app-data:/app/data
-      - ./beanflows/litestream.yml:/etc/litestream.yml:ro
+      - ./litestream.yml:/etc/litestream.yml:ro

   # ── Blue slot ─────────────────────────────────────────────

   blue-app:
     profiles: ["blue"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:

@@ -47,10 +47,10 @@ services:
   blue-worker:
     profiles: ["blue"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
     command: python -m beanflows.worker
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:

@@ -61,10 +61,10 @@ services:
   blue-scheduler:
     profiles: ["blue"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
     command: python -m beanflows.worker scheduler
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:

@@ -77,9 +77,9 @@ services:
   green-app:
     profiles: ["green"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:

@@ -96,10 +96,10 @@ services:
   green-worker:
     profiles: ["green"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
     command: python -m beanflows.worker
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:

@@ -110,10 +110,10 @@ services:
   green-scheduler:
     profiles: ["green"]
     build:
-      context: ./beanflows
+      context: .
     restart: unless-stopped
     command: python -m beanflows.worker scheduler
-    env_file: ./beanflows/.env
+    env_file: ./.env
     environment:
       - DATABASE_PATH=/app/data/app.db
     volumes:
@@ -12,8 +12,9 @@ dependencies = [
     "aiosqlite>=0.19.0",
     "duckdb>=1.0.0",
-    "httpx>=0.27.0",
+    "resend>=2.22.0",
     "python-dotenv>=1.0.0",
+    "paddle-python-sdk>=1.13.0",
     "itsdangerous>=2.1.0",
     "jinja2>=3.1.0",
     "hypercorn>=0.17.0",
@@ -52,14 +52,40 @@ def create_app() -> Quart:
         response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
         return response

-    # Load current user before each request
+    # Load current user + subscription + roles before each request
     @app.before_request
     async def load_user():
         g.user = None
+        g.subscription = None
         user_id = session.get("user_id")
         if user_id:
-            from .auth.routes import get_user_by_id
-            g.user = await get_user_by_id(user_id)
+            from .core import fetch_one as _fetch_one
+            row = await _fetch_one(
+                """SELECT u.*,
+                          bc.provider_customer_id,
+                          (SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv,
+                          s.id AS sub_id, s.plan, s.status AS sub_status,
+                          s.provider_subscription_id, s.current_period_end
+                   FROM users u
+                   LEFT JOIN billing_customers bc ON bc.user_id = u.id
+                   LEFT JOIN subscriptions s ON s.id = (
+                       SELECT id FROM subscriptions
+                       WHERE user_id = u.id
+                       ORDER BY created_at DESC LIMIT 1
+                   )
+                   WHERE u.id = ? AND u.deleted_at IS NULL""",
+                (user_id,),
+            )
+            if row:
+                g.user = dict(row)
+                g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else []
+                if row["sub_id"]:
+                    g.subscription = {
+                        "id": row["sub_id"], "plan": row["plan"],
+                        "status": row["sub_status"],
+                        "provider_subscription_id": row["provider_subscription_id"],
+                        "current_period_end": row["current_period_end"],
+                    }

     # Template context globals
     @app.context_processor
@@ -68,8 +94,12 @@ def create_app() -> Quart:
         return {
             "config": config,
             "user": g.get("user"),
+            "subscription": g.get("subscription"),
+            "is_admin": "admin" in (g.get("user") or {}).get("roles", []),
             "now": datetime.utcnow(),
             "csrf_token": get_csrf_token,
+            "ab_variant": getattr(g, "ab_variant", None),
+            "ab_tag": getattr(g, "ab_tag", None),
         }

     # Health check
@@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None:
     )


-async def get_user_with_subscription(user_id: int) -> dict | None:
-    """Get user with their active subscription info."""
-    return await fetch_one(
-        """
-        SELECT u.*, s.plan, s.status as sub_status, s.current_period_end
-        FROM users u
-        LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active'
-        WHERE u.id = ? AND u.deleted_at IS NULL
-        """,
-        (user_id,)
-    )
-
-
 # =============================================================================
 # Decorators
 # =============================================================================

@@ -116,8 +103,53 @@ def login_required(f):
     return decorated


-def subscription_required(plans: list[str] = None):
-    """Require active subscription, optionally of specific plan(s)."""
+def role_required(*roles):
+    """Require user to have at least one of the given roles."""
+    def decorator(f):
+        @wraps(f)
+        async def decorated(*args, **kwargs):
+            if not g.get("user"):
+                await flash("Please sign in to continue.", "warning")
+                return redirect(url_for("auth.login", next=request.path))
+            user_roles = g.user.get("roles", [])
+            if not any(r in user_roles for r in roles):
+                await flash("You don't have permission to access that page.", "error")
+                return redirect(url_for("dashboard.index"))
+            return await f(*args, **kwargs)
+        return decorated
+    return decorator
+
+
+async def grant_role(user_id: int, role: str) -> None:
+    """Grant a role to a user (idempotent)."""
+    await execute(
+        "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
+        (user_id, role),
+    )
+
+
+async def revoke_role(user_id: int, role: str) -> None:
+    """Revoke a role from a user."""
+    await execute(
+        "DELETE FROM user_roles WHERE user_id = ? AND role = ?",
+        (user_id, role),
+    )
+
+
+async def ensure_admin_role(user_id: int, email: str) -> None:
+    """Grant admin role if email is in ADMIN_EMAILS."""
+    if email.lower() in config.ADMIN_EMAILS:
+        await grant_role(user_id, "admin")
+
+
+def subscription_required(
+    plans: list[str] = None,
+    allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"),
+):
+    """Require active subscription, optionally of specific plan(s) and/or statuses.
+
+    Reads from g.subscription (eager-loaded in load_user) — zero extra queries.
+    """
     def decorator(f):
         @wraps(f)
         async def decorated(*args, **kwargs):
@@ -125,12 +157,12 @@ def subscription_required(plans: list[str] = None):
                 await flash("Please sign in to continue.", "warning")
                 return redirect(url_for("auth.login"))

-            user = await get_user_with_subscription(g.user["id"])
-            if not user or not user.get("plan"):
+            sub = g.get("subscription")
+            if not sub or sub["status"] not in allowed:
                 await flash("Please subscribe to access this feature.", "warning")
                 return redirect(url_for("billing.pricing"))

-            if plans and user["plan"] not in plans:
+            if plans and sub["plan"] not in plans:
                 await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning")
                 return redirect(url_for("billing.pricing"))
@@ -246,6 +278,9 @@ async def verify():
     session.permanent = True
     session["user_id"] = token_data["user_id"]

+    # Auto-grant admin role if email is in ADMIN_EMAILS
+    await ensure_admin_role(token_data["user_id"], token_data["email"])
+
     await flash("Successfully signed in!", "success")

     # Redirect to intended page or dashboard

@@ -286,6 +321,9 @@ async def dev_login():
     session.permanent = True
     session["user_id"] = user_id

+    # Auto-grant admin role if email is in ADMIN_EMAILS
+    await ensure_admin_role(user_id, email)
+
     await flash(f"Dev login as {email}", "success")
     return redirect(url_for("dashboard.index"))
@@ -2,16 +2,19 @@
 Core infrastructure: database, config, email, and shared utilities.
 """
 import os
+import random
 import secrets
 import hashlib
 import hmac

 import aiosqlite
+import resend
-import httpx
 from pathlib import Path
 from functools import wraps
 from datetime import datetime, timedelta
 from contextvars import ContextVar
-from quart import request, session, g
+from quart import g, make_response, request, session
 from dotenv import load_dotenv

 # web/.env is three levels up from web/src/beanflows/core.py

@@ -36,14 +39,22 @@ class Config:

     PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "")
     PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "")
+    PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
     PADDLE_PRICES: dict = {
         "starter": os.getenv("PADDLE_PRICE_STARTER", ""),
         "pro": os.getenv("PADDLE_PRICE_PRO", ""),
     }

+    UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "")
+    UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "")
+
     RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "")
     EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com")

+    ADMIN_EMAILS: list[str] = [
+        e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
+    ]
+
     RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
     RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
@@ -153,25 +164,32 @@ class transaction:
 # Email
 # =============================================================================

-async def send_email(to: str, subject: str, html: str, text: str = None) -> bool:
-    """Send email via Resend API."""
+EMAIL_ADDRESSES = {
+    "transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>",
+}
+
+
+async def send_email(
+    to: str, subject: str, html: str, text: str = None, from_addr: str = None
+) -> bool:
+    """Send email via Resend SDK."""
     if not config.RESEND_API_KEY:
         print(f"[EMAIL] Would send to {to}: {subject}")
         return True

-    async with httpx.AsyncClient() as client:
-        response = await client.post(
-            "https://api.resend.com/emails",
-            headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"},
-            json={
-                "from": config.EMAIL_FROM,
-                ...
-            },
-        )
-        return response.status_code == 200
+    resend.api_key = config.RESEND_API_KEY
+    try:
+        resend.Emails.send({
+            "from": from_addr or config.EMAIL_FROM,
+            "to": to,
+            "subject": subject,
+            "html": html,
+            "text": text or html,
+        })
+        return True
+    except Exception as e:
+        print(f"[EMAIL] Error sending to {to}: {e}")
+        return False
 # =============================================================================
 # CSRF Protection

@@ -294,13 +312,11 @@ def setup_request_id(app):
 # Webhook Signature Verification
 # =============================================================================

 def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
     """Verify HMAC-SHA256 webhook signature."""
     expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
     return hmac.compare_digest(signature, expected)
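The retained `verify_hmac_signature` helper can be exercised standalone with a throwaway secret (the secret and payload below are made up for illustration):

```python
import hashlib
import hmac


def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
    """Verify an HMAC-SHA256 webhook signature (same logic as the helper kept above)."""
    expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    return hmac.compare_digest(signature, expected)


body = b'{"event": "payment.succeeded"}'
# A real provider computes this on its side and sends it in a header.
sig = hmac.new(b"whsec_test", body, hashlib.sha256).hexdigest()

print(verify_hmac_signature(body, sig, "whsec_test"))        # → True
print(verify_hmac_signature(body, "deadbeef", "whsec_test"))  # → False
```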
 # =============================================================================
 # Soft Delete Helpers
 # =============================================================================

@@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int:
         f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?",
         (cutoff,)
     )
+
+
+# =============================================================================
+# A/B Testing
+# =============================================================================
+
+def ab_test(experiment: str, variants: tuple = ("control", "treatment")):
+    """Assign visitor to an A/B test variant via cookie, tag Umami pageviews."""
+    def decorator(f):
+        @wraps(f)
+        async def wrapper(*args, **kwargs):
+            cookie_key = f"ab_{experiment}"
+            assigned = request.cookies.get(cookie_key)
+            if assigned not in variants:
+                assigned = random.choice(variants)
+
+            g.ab_variant = assigned
+            g.ab_tag = f"{experiment}-{assigned}"
+
+            response = await make_response(await f(*args, **kwargs))
+            response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60)
+            return response
+        return wrapper
+    return decorator
@@ -1,47 +1,91 @@
 """
-Simple migration runner. Runs schema.sql against the database.
+Sequential migration runner.
+
+Replays all migrations in order. All databases — fresh and existing —
+go through the same path. No schema.sql fast-path.
+
+- Scans versions/ for NNNN_*.py files and runs unapplied ones in order
+- Each migration has an up(conn) function receiving an uncommitted connection
+- All pending migrations share a single transaction (batch atomicity)
 """
-import sqlite3
-from pathlib import Path
-import os
-import sys
+
+import importlib
+import os
+import re
+import sqlite3
+import sys
+from pathlib import Path

 # Add parent to path for imports
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

 from dotenv import load_dotenv

 load_dotenv()

+VERSIONS_DIR = Path(__file__).parent / "versions"
+VERSION_RE = re.compile(r"^(\d{4})_.+\.py$")
+
+# Derived from the package path: …/src/<slug>/migrations/migrate.py
+_PACKAGE = Path(__file__).parent.parent.name  # e.g. "myproject"
+
+
+def _discover_versions():
+    """Return sorted list of version file stems."""
+    if not VERSIONS_DIR.is_dir():
+        return []
+    versions = []
+    for f in sorted(VERSIONS_DIR.iterdir()):
+        if VERSION_RE.match(f.name):
+            versions.append(f.stem)
+    return versions
+

-def migrate():
-    """Run migrations."""
-    # Get database path from env or default
-    db_path = os.getenv("DATABASE_PATH", "data/app.db")
+def migrate(db_path=None):
+    if db_path is None:
+        db_path = os.getenv("DATABASE_PATH", "data/app.db")

     # Ensure directory exists
     Path(db_path).parent.mkdir(parents=True, exist_ok=True)

-    # Read schema
-    schema_path = Path(__file__).parent / "schema.sql"
-    schema = schema_path.read_text()
-
     # Connect and execute
     conn = sqlite3.connect(db_path)

     # Enable WAL mode
     conn.execute("PRAGMA journal_mode=WAL")
     conn.execute("PRAGMA foreign_keys=ON")

-    # Run schema
-    conn.executescript(schema)
-
-    print(f"✓ Migrations complete: {db_path}")
+    # Ensure tracking table exists before anything else
+    conn.execute("""
+        CREATE TABLE IF NOT EXISTS _migrations (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            name TEXT UNIQUE NOT NULL,
+            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
+        )
+    """)
+    conn.commit()
+
+    versions = _discover_versions()
+    applied = {
+        row[0]
+        for row in conn.execute("SELECT name FROM _migrations").fetchall()
+    }
+    pending = [v for v in versions if v not in applied]

-    # Show tables
+    if pending:
+        for name in pending:
+            print(f"  Applying {name}...")
+            mod = importlib.import_module(
+                f"{_PACKAGE}.migrations.versions.{name}"
+            )
+            mod.up(conn)
+            conn.execute(
+                "INSERT INTO _migrations (name) VALUES (?)", (name,)
+            )
+        conn.commit()
+        print(f"✓ Applied {len(pending)} migration(s): {db_path}")
+    else:
+        print(f"✓ All migrations already applied: {db_path}")
+
+    # Show tables (excluding internal sqlite/fts tables)
     cursor = conn.execute(
-        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
+        "SELECT name FROM sqlite_master WHERE type='table'"
+        " AND name NOT LIKE 'sqlite_%'"
+        " ORDER BY name"
     )
     tables = [row[0] for row in cursor.fetchall()]
     print(f"  Tables: {', '.join(tables)}")
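Given this runner, a file under `versions/` only needs an `up(conn)` function. A hypothetical `0002_add_notes.py` (the table and module name are made up), exercised against an in-memory database:

```python
import sqlite3


# Hypothetical versions/0002_add_notes.py — the runner imports the module
# and calls up(conn) inside the shared, uncommitted transaction.
def up(conn: sqlite3.Connection) -> None:
    conn.execute("""
        CREATE TABLE IF NOT EXISTS notes (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            body TEXT NOT NULL,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
    """)


conn = sqlite3.connect(":memory:")
up(conn)
conn.commit()  # the real runner commits once, after all pending migrations succeed

tables = [r[0] for r in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)]
print(tables)  # → ['notes']
```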
@@ -1,6 +1,13 @@
 -- BeanFlows Database Schema
 -- Run with: python -m beanflows.migrations.migrate

+-- Migration tracking
+CREATE TABLE IF NOT EXISTS _migrations (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    name TEXT UNIQUE NOT NULL,
+    applied_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+
 -- Users
 CREATE TABLE IF NOT EXISTS users (
     id INTEGER PRIMARY KEY AUTOINCREMENT,

@@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens (
 CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token);
 CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id);

+-- User Roles
+CREATE TABLE IF NOT EXISTS user_roles (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    role TEXT NOT NULL,
+    granted_at TEXT NOT NULL DEFAULT (datetime('now')),
+    UNIQUE(user_id, role)
+);
+CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role);
+
+-- Billing Customers (payment provider identity, separate from subscriptions)
+CREATE TABLE IF NOT EXISTS billing_customers (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
+    provider_customer_id TEXT NOT NULL,
+    created_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id);
+CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id);
+
 -- Subscriptions
 CREATE TABLE IF NOT EXISTS subscriptions (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
-    user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
+    user_id INTEGER NOT NULL REFERENCES users(id),
     plan TEXT NOT NULL DEFAULT 'free',
     status TEXT NOT NULL DEFAULT 'free',
-    paddle_customer_id TEXT,
-    paddle_subscription_id TEXT,
+    provider_subscription_id TEXT,
     current_period_end TEXT,
     created_at TEXT NOT NULL,
     updated_at TEXT
 );

 CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
-CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id);
+CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id);

+-- Transactions
+CREATE TABLE IF NOT EXISTS transactions (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    subscription_id INTEGER REFERENCES subscriptions(id),
+    provider_transaction_id TEXT UNIQUE,
+    type TEXT NOT NULL DEFAULT 'payment',
+    amount_cents INTEGER,
+    currency TEXT DEFAULT 'USD',
+    status TEXT NOT NULL DEFAULT 'pending',
+    metadata TEXT,
+    created_at TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id);
+CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id);

 -- API Keys
 CREATE TABLE IF NOT EXISTS api_keys (

@@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks (
 );

 CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at);
+
+-- Items (example domain entity - replace with your domain)
+CREATE TABLE IF NOT EXISTS items (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    name TEXT NOT NULL,
+    data TEXT,
+    created_at TEXT NOT NULL,
+    updated_at TEXT,
+    deleted_at TEXT
+);
+
+CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
+CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);
+
+-- Full-text search for items (optional)
+CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
+    name,
+    data,
+    content='items',
+    content_rowid='id'
+);
+
+-- FTS triggers
+CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
+    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
+END;
+
+CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
+    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
+END;
+
+CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
+    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
+    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
+END;
0	web/src/beanflows/scripts/__init__.py	Normal file
92	web/src/beanflows/scripts/setup_paddle.py	Normal file
@@ -0,0 +1,92 @@
"""
|
||||
Create Paddle products and prices for BeanFlows.
|
||||
|
||||
Run once per environment (sandbox, then production).
|
||||
Prints resulting price IDs as a .env snippet.
|
||||
|
||||
Usage:
|
||||
uv run python -m beanflows.scripts.setup_paddle
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from paddle_billing import Client as PaddleClient
|
||||
from paddle_billing import Environment, Options
|
||||
from paddle_billing.Entities.Shared import CurrencyCode, Money, TaxCategory
|
||||
from paddle_billing.Resources.Prices.Operations import CreatePrice
from paddle_billing.Resources.Products.Operations import CreateProduct

load_dotenv()

PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")

if not PADDLE_API_KEY:
    print("ERROR: Set PADDLE_API_KEY in .env first")
    sys.exit(1)


PRODUCTS = [
    # Subscriptions
    {
        "name": "Starter",
        "env_key": "PADDLE_PRICE_STARTER",
        "price": 900,
        "currency": CurrencyCode.USD,
        "interval": "month",
        "type": "subscription",
    },
    {
        "name": "Pro",
        "env_key": "PADDLE_PRICE_PRO",
        "price": 2900,
        "currency": CurrencyCode.USD,
        "interval": "month",
        "type": "subscription",
    },
]


def main():
    env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
    paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))

    print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")

    env_lines = []

    for spec in PRODUCTS:
        # Create product
        product = paddle.products.create(CreateProduct(
            name=spec["name"],
            tax_category=TaxCategory.Standard,
        ))
        print(f"  Product: {spec['name']} -> {product.id}")

        # Create price
        price_kwargs = {
            "description": spec["name"],
            "product_id": product.id,
            "unit_price": Money(str(spec["price"]), spec["currency"]),
        }

        if spec["type"] == "subscription":
            from paddle_billing.Entities.Shared import TimePeriod
            price_kwargs["billing_cycle"] = TimePeriod(interval="month", frequency=1)

        price = paddle.prices.create(CreatePrice(**price_kwargs))
        print(f"  Price: {spec['env_key']} = {price.id}")

        env_lines.append(f"{spec['env_key']}={price.id}")

    print("\n# --- .env snippet ---")
    for line in env_lines:
        print(line)


if __name__ == "__main__":
    main()
BIN
web/src/beanflows/static/fonts/CommitMono-400-Regular.woff2
Normal file
Binary file not shown.
BIN
web/src/beanflows/static/fonts/CommitMono-700-Regular.woff2
Normal file
Binary file not shown.
90
web/src/beanflows/static/fonts/CommitMono-LICENSE.txt
Normal file
@@ -0,0 +1,90 @@
This Font Software is licensed under the SIL Open Font License, Version 1.1.
This license is copied below, and is also available with a FAQ at:
http://scripts.sil.org/OFL


-----------------------------------------------------------
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
-----------------------------------------------------------

PREAMBLE
The goals of the Open Font License (OFL) are to stimulate worldwide
development of collaborative font projects, to support the font creation
efforts of academic and linguistic communities, and to provide a free and
open framework in which fonts may be shared and improved in partnership
with others.

The OFL allows the licensed fonts to be used, studied, modified and
redistributed freely as long as they are not sold by themselves. The
fonts, including any derivative works, can be bundled, embedded,
redistributed and/or sold with any software provided that any reserved
names are not used by derivative works. The fonts and derivatives,
however, cannot be released under any other type of license. The
requirement for fonts to remain under this license does not apply
to any document created using the fonts or their derivatives.

DEFINITIONS
"Font Software" refers to the set of files released by the Copyright
Holder(s) under this license and clearly marked as such. This may
include source files, build scripts and documentation.

"Reserved Font Name" refers to any names specified as such after the
copyright statement(s).

"Original Version" refers to the collection of Font Software components as
distributed by the Copyright Holder(s).

"Modified Version" refers to any derivative made by adding to, deleting,
or substituting -- in part or in whole -- any of the components of the
Original Version, by changing formats or by porting the Font Software to a
new environment.

"Author" refers to any designer, engineer, programmer, technical
writer or other person who contributed to the Font Software.

PERMISSION & CONDITIONS
Permission is hereby granted, free of charge, to any person obtaining
a copy of the Font Software, to use, study, copy, merge, embed, modify,
redistribute, and sell modified and unmodified copies of the Font
Software, subject to the following conditions:

1) Neither the Font Software nor any of its individual components,
in Original or Modified Versions, may be sold by itself.

2) Original or Modified Versions of the Font Software may be bundled,
redistributed and/or sold with any software, provided that each copy
contains the above copyright notice and this license. These can be
included either as stand-alone text files, human-readable headers or
in the appropriate machine-readable metadata fields within text or
binary files as long as those fields can be easily viewed by the user.

3) No Modified Version of the Font Software may use the Reserved Font
Name(s) unless explicit written permission is granted by the corresponding
Copyright Holder. This restriction only applies to the primary font name as
presented to the users.

4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
Software shall not be used to promote, endorse or advertise any
Modified Version, except to acknowledge the contribution(s) of the
Copyright Holder(s) and the Author(s) or with their explicit written
permission.

5) The Font Software, modified or unmodified, in part or in whole,
must be distributed entirely under this license, and must not be
distributed under any other license. The requirement for fonts to
remain under this license does not apply to any document created
using the Font Software.

TERMINATION
This license becomes null and void if any of the above conditions are
not met.

DISCLAIMER
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
OTHER DEALINGS IN THE FONT SOFTWARE.
@@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email
HANDLERS: dict[str, callable] = {}


def _email_wrap(body: str) -> str:
    """Wrap email body in a branded layout with inline CSS."""
    return f"""\
<!DOCTYPE html>
<html lang="en">
<head><meta charset="utf-8"></head>
<body style="margin:0;padding:0;background-color:#F8FAFC;font-family:'Inter',Helvetica,Arial,sans-serif;">
<table width="100%" cellpadding="0" cellspacing="0" style="background-color:#F8FAFC;padding:40px 0;">
<tr><td align="center">
<table width="480" cellpadding="0" cellspacing="0" style="background-color:#FFFFFF;border-radius:8px;border:1px solid #E2E8F0;overflow:hidden;">
<!-- Header -->
<tr><td style="background-color:#0F172A;padding:24px 32px;">
<span style="color:#FFFFFF;font-size:18px;font-weight:700;letter-spacing:-0.02em;">{config.APP_NAME}</span>
</td></tr>
<!-- Body -->
<tr><td style="padding:32px;color:#475569;font-size:15px;line-height:1.6;">
{body}
</td></tr>
<!-- Footer -->
<tr><td style="padding:20px 32px;border-top:1px solid #E2E8F0;text-align:center;">
<span style="color:#94A3B8;font-size:12px;">© {config.APP_NAME} · You received this because you have an account.</span>
</td></tr>
</table>
</td></tr>
</table>
</body>
</html>"""


def _email_button(url: str, label: str) -> str:
    """Render a branded CTA button for email."""
    return (
        f'<table cellpadding="0" cellspacing="0" style="margin:24px 0;">'
        f'<tr><td style="background-color:#3B82F6;border-radius:6px;text-align:center;">'
        f'<a href="{url}" style="display:inline-block;padding:12px 28px;'
        f'color:#FFFFFF;font-size:15px;font-weight:600;text-decoration:none;">'
        f'{label}</a></td></tr></table>'
    )


def task(name: str):
    """Decorator to register a task handler."""
    def decorator(f):
@@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None:
        subject=payload["subject"],
        html=payload["html"],
        text=payload.get("text"),
        from_addr=payload.get("from_addr"),
    )
@@ -107,34 +148,36 @@ async def handle_send_magic_link(payload: dict) -> None:
    """Send magic link email."""
    link = f"{config.BASE_URL}/auth/verify?token={payload['token']}"

-    html = f"""
-    <h2>Sign in to {config.APP_NAME}</h2>
-    <p>Click the link below to sign in:</p>
-    <p><a href="{link}">{link}</a></p>
-    <p>This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>
-    <p>If you didn't request this, you can safely ignore this email.</p>
-    """
+    body = (
+        f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Sign in to {config.APP_NAME}</h2>'
+        f"<p>Click the button below to sign in. This link expires in "
+        f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>"
+        f"{_email_button(link, 'Sign In')}"
+        f'<p style="font-size:13px;color:#94A3B8;">If the button doesn\'t work, copy and paste this URL into your browser:</p>'
+        f'<p style="font-size:13px;color:#94A3B8;word-break:break-all;">{link}</p>'
+        f'<p style="font-size:13px;color:#94A3B8;">If you didn\'t request this, you can safely ignore this email.</p>'
+    )

    await send_email(
        to=payload["email"],
        subject=f"Sign in to {config.APP_NAME}",
-        html=html,
+        html=_email_wrap(body),
    )


@task("send_welcome")
async def handle_send_welcome(payload: dict) -> None:
    """Send welcome email to new user."""
-    html = f"""
-    <h2>Welcome to {config.APP_NAME}!</h2>
-    <p>Thanks for signing up. We're excited to have you.</p>
-    <p><a href="{config.BASE_URL}/dashboard">Go to your dashboard</a></p>
-    """
+    body = (
+        f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Welcome to {config.APP_NAME}!</h2>'
+        f"<p>Thanks for signing up. We're excited to have you.</p>"
+        f'{_email_button(f"{config.BASE_URL}/dashboard", "Go to Dashboard")}'
+    )

    await send_email(
        to=payload["email"],
        subject=f"Welcome to {config.APP_NAME}",
-        html=html,
+        html=_email_wrap(body),
    )
@@ -10,9 +10,11 @@ from unittest.mock import AsyncMock, patch
import aiosqlite
import pytest

-from beanflows import analytics, core
+from beanflows import core
from beanflows.app import create_app


SCHEMA_PATH = Path(__file__).parent.parent / "src" / "beanflows" / "migrations" / "schema.sql"
@@ -44,9 +46,7 @@ async def db():
async def app(db):
    """Quart app with DB already initialized (init_db/close_db patched to no-op)."""
    with patch.object(core, "init_db", new_callable=AsyncMock), \
-         patch.object(core, "close_db", new_callable=AsyncMock), \
-         patch.object(analytics, "open_analytics_db"), \
-         patch.object(analytics, "close_analytics_db"):
+         patch.object(core, "close_db", new_callable=AsyncMock):
        application = create_app()
        application.config["TESTING"] = True
        yield application
@@ -92,22 +92,17 @@ def create_subscription(db):
        user_id: int,
        plan: str = "pro",
        status: str = "active",
-        paddle_customer_id: str = "ctm_test123",
-        paddle_subscription_id: str = "sub_test456",
+        provider_subscription_id: str = "sub_test456",
        current_period_end: str = "2025-03-01T00:00:00Z",
    ) -> int:
        now = datetime.utcnow().isoformat()
        async with db.execute(
            """INSERT INTO subscriptions
-               (user_id, plan, status, paddle_customer_id,
-                paddle_subscription_id, current_period_end, created_at, updated_at)
-               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
-            (user_id, plan, status, paddle_customer_id, paddle_subscription_id,
+               (user_id, plan, status,
+                provider_subscription_id, current_period_end, created_at, updated_at)
+               VALUES (?, ?, ?, ?, ?, ?, ?)""",
+            (user_id, plan, status, provider_subscription_id,
             current_period_end, now, now),
        ) as cursor:
            sub_id = cursor.lastrowid
        await db.commit()
@@ -115,6 +110,48 @@ def create_subscription(db):
    return _create


# ── Billing Customers ───────────────────────────────────────

@pytest.fixture
def create_billing_customer(db):
    """Factory: create a billing_customers row for a user."""
    async def _create(user_id: int, provider_customer_id: str = "cust_test123") -> int:
        async with db.execute(
            """INSERT INTO billing_customers (user_id, provider_customer_id)
               VALUES (?, ?)
               ON CONFLICT(user_id) DO UPDATE SET provider_customer_id = excluded.provider_customer_id""",
            (user_id, provider_customer_id),
        ) as cursor:
            row_id = cursor.lastrowid
        await db.commit()
        return row_id
    return _create


# ── Roles ───────────────────────────────────────────────────

@pytest.fixture
def grant_role(db):
    """Factory: grant a role to a user."""
    async def _grant(user_id: int, role: str) -> None:
        await db.execute(
            "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
            (user_id, role),
        )
        await db.commit()
    return _grant


@pytest.fixture
async def admin_client(app, test_user, grant_role):
    """Test client with admin role and session['user_id'] pre-set."""
    await grant_role(test_user["id"], "admin")
    async with app.test_client() as c:
        async with c.session_transaction() as sess:
            sess["user_id"] = test_user["id"]
        yield c


# ── Config ───────────────────────────────────────────────────

@pytest.fixture(autouse=True)
@@ -127,6 +164,7 @@ def patch_config():
    "PADDLE_API_KEY": "test_api_key_123",
    "PADDLE_WEBHOOK_SECRET": "whsec_test_secret",
    "PADDLE_ENVIRONMENT": "sandbox",
+    "PADDLE_PRICES": {"starter": "pri_starter_123", "pro": "pri_pro_456"},
    "BASE_URL": "http://localhost:5000",
@@ -147,6 +185,32 @@ def patch_config():
# ── Webhook helpers ──────────────────────────────────────────


@pytest.fixture(autouse=True)
def mock_paddle_verifier(monkeypatch):
    """Mock Paddle's webhook Verifier to accept test payloads."""
    def mock_verify(self, payload, secret, signature):
        if not signature or signature == "invalid_signature":
            raise ValueError("Invalid signature")

    monkeypatch.setattr(
        "paddle_billing.Notifications.Verifier.verify",
        mock_verify,
    )


@pytest.fixture
def mock_paddle_client(monkeypatch):
    """Mock _paddle_client() to return a fake PaddleClient."""
    from unittest.mock import MagicMock

    mock_client = MagicMock()
    monkeypatch.setattr(
        "beanflows.billing.routes._paddle_client",
        lambda: mock_client,
    )
    return mock_client


def make_webhook_payload(
    event_type: str,
    subscription_id: str = "sub_test456",
@@ -172,76 +236,8 @@ def make_webhook_payload(
    }


-def sign_payload(payload_bytes: bytes, secret: str = "whsec_test_secret") -> str:
-    """Compute HMAC-SHA256 signature for a webhook payload."""
-    return hmac.new(secret.encode(), payload_bytes, hashlib.sha256).hexdigest()
-
-
-# ── Analytics mock data ──────────────────────────────────────
-
-MOCK_TIME_SERIES = [
-    {"market_year": 2018, "Production": 165000, "Exports": 115000, "Imports": 105000,
-     "Ending_Stocks": 33000, "Total_Distribution": 160000},
-    {"market_year": 2019, "Production": 168000, "Exports": 118000, "Imports": 108000,
-     "Ending_Stocks": 34000, "Total_Distribution": 163000},
-    {"market_year": 2020, "Production": 170000, "Exports": 120000, "Imports": 110000,
-     "Ending_Stocks": 35000, "Total_Distribution": 165000},
-    {"market_year": 2021, "Production": 175000, "Exports": 125000, "Imports": 115000,
-     "Ending_Stocks": 36000, "Total_Distribution": 170000},
-    {"market_year": 2022, "Production": 172000, "Exports": 122000, "Imports": 112000,
-     "Ending_Stocks": 34000, "Total_Distribution": 168000},
-]
-
-MOCK_TOP_COUNTRIES = [
-    {"country_name": "Brazil", "country_code": "BR", "market_year": 2022, "Production": 65000},
-    {"country_name": "Vietnam", "country_code": "VN", "market_year": 2022, "Production": 30000},
-    {"country_name": "Colombia", "country_code": "CO", "market_year": 2022, "Production": 14000},
-]
-
-MOCK_STU_TREND = [
-    {"market_year": 2020, "Stock_to_Use_Ratio_pct": 21.2},
-    {"market_year": 2021, "Stock_to_Use_Ratio_pct": 21.1},
-    {"market_year": 2022, "Stock_to_Use_Ratio_pct": 20.2},
-]
-
-MOCK_BALANCE = [
-    {"market_year": 2020, "Production": 170000, "Total_Distribution": 165000, "Supply_Demand_Balance": 5000},
-    {"market_year": 2021, "Production": 175000, "Total_Distribution": 170000, "Supply_Demand_Balance": 5000},
-    {"market_year": 2022, "Production": 172000, "Total_Distribution": 168000, "Supply_Demand_Balance": 4000},
-]
-
-MOCK_YOY = [
-    {"country_name": "Brazil", "country_code": "BR", "market_year": 2022,
-     "Production": 65000, "Production_YoY_pct": -3.5},
-    {"country_name": "Vietnam", "country_code": "VN", "market_year": 2022,
-     "Production": 30000, "Production_YoY_pct": 2.1},
-]
-
-MOCK_COMMODITIES = [
-    {"commodity_code": 711100, "commodity_name": "Coffee, Green"},
-    {"commodity_code": 222000, "commodity_name": "Soybeans"},
-]
-
-
-@pytest.fixture
-def mock_analytics():
-    """Patch all analytics query functions with mock data."""
-    with patch.object(analytics, "get_global_time_series", new_callable=AsyncMock,
-                      return_value=MOCK_TIME_SERIES), \
-         patch.object(analytics, "get_top_countries", new_callable=AsyncMock,
-                      return_value=MOCK_TOP_COUNTRIES), \
-         patch.object(analytics, "get_stock_to_use_trend", new_callable=AsyncMock,
-                      return_value=MOCK_STU_TREND), \
-         patch.object(analytics, "get_supply_demand_balance", new_callable=AsyncMock,
-                      return_value=MOCK_BALANCE), \
-         patch.object(analytics, "get_production_yoy_by_country", new_callable=AsyncMock,
-                      return_value=MOCK_YOY), \
-         patch.object(analytics, "get_country_comparison", new_callable=AsyncMock,
-                      return_value=[]), \
-         patch.object(analytics, "get_available_commodities", new_callable=AsyncMock,
-                      return_value=MOCK_COMMODITIES), \
-         patch.object(analytics, "fetch_analytics", new_callable=AsyncMock,
-                      return_value=[{"result": 1}]):
-        yield
+def sign_payload(payload_bytes: bytes) -> str:
+    """Return a dummy signature for Paddle webhook tests (Verifier is mocked)."""
+    return "ts=1234567890;h1=dummy_signature"
@@ -9,10 +9,13 @@ from hypothesis import strategies as st
from beanflows.billing.routes import (
    can_access_feature,
    get_billing_customer,
    get_subscription,
    get_subscription_by_provider_id,
    is_within_limits,
    record_transaction,
    update_subscription_status,
    upsert_billing_customer,
    upsert_subscription,
)
from beanflows.core import config
@@ -45,7 +48,6 @@ class TestUpsertSubscription:
            user_id=test_user["id"],
            plan="pro",
            status="active",
-            provider_customer_id="cust_abc",
            provider_subscription_id="sub_xyz",
            current_period_end="2025-06-01T00:00:00Z",
        )
@@ -53,39 +55,53 @@ class TestUpsertSubscription:
        row = await get_subscription(test_user["id"])
        assert row["plan"] == "pro"
        assert row["status"] == "active"
-        assert row["paddle_customer_id"] == "cust_abc"
-        assert row["paddle_subscription_id"] == "sub_xyz"
+        assert row["provider_subscription_id"] == "sub_xyz"
        assert row["current_period_end"] == "2025-06-01T00:00:00Z"

-    async def test_update_existing_subscription(self, db, test_user, create_subscription):
-        original_id = await create_subscription(
-            test_user["id"], plan="starter", status="active",
-            paddle_subscription_id="sub_old",
-        )
+    async def test_update_existing_by_provider_subscription_id(self, db, test_user):
+        """upsert finds existing by provider_subscription_id, not user_id."""
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="starter",
+            status="active",
+            provider_subscription_id="sub_same",
+        )
        returned_id = await upsert_subscription(
            user_id=test_user["id"],
            plan="pro",
            status="active",
-            provider_customer_id="cust_new",
-            provider_subscription_id="sub_new",
+            provider_subscription_id="sub_same",
        )
-        assert returned_id == original_id
        row = await get_subscription(test_user["id"])
        assert row["plan"] == "pro"
-        assert row["paddle_subscription_id"] == "sub_new"
+        assert row["provider_subscription_id"] == "sub_same"

+    async def test_different_provider_id_creates_new(self, db, test_user):
+        """Different provider_subscription_id creates a new row (multi-sub support)."""
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="starter",
+            status="active",
+            provider_subscription_id="sub_first",
+        )
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="pro",
+            status="active",
+            provider_subscription_id="sub_second",
+        )
+        from beanflows.core import fetch_all
+        rows = await fetch_all(
+            "SELECT * FROM subscriptions WHERE user_id = ? ORDER BY created_at",
+            (test_user["id"],),
+        )
+        assert len(rows) == 2

    async def test_upsert_with_none_period_end(self, db, test_user):
        await upsert_subscription(
            user_id=test_user["id"],
            plan="pro",
            status="active",
-            provider_customer_id="cust_1",
            provider_subscription_id="sub_1",
            current_period_end=None,
        )
@@ -93,6 +109,28 @@ class TestUpsertSubscription:
        assert row["current_period_end"] is None


# ════════════════════════════════════════════════════════════
# upsert_billing_customer / get_billing_customer
# ════════════════════════════════════════════════════════════

class TestUpsertBillingCustomer:
    async def test_creates_billing_customer(self, db, test_user):
        await upsert_billing_customer(test_user["id"], "cust_abc")
        row = await get_billing_customer(test_user["id"])
        assert row is not None
        assert row["provider_customer_id"] == "cust_abc"

    async def test_updates_existing_customer(self, db, test_user):
        await upsert_billing_customer(test_user["id"], "cust_old")
        await upsert_billing_customer(test_user["id"], "cust_new")
        row = await get_billing_customer(test_user["id"])
        assert row["provider_customer_id"] == "cust_new"

    async def test_get_returns_none_for_unknown_user(self, db):
        row = await get_billing_customer(99999)
        assert row is None


# ════════════════════════════════════════════════════════════
# get_subscription_by_provider_id
# ════════════════════════════════════════════════════════════
@@ -102,10 +140,8 @@ class TestGetSubscriptionByProviderId:
        result = await get_subscription_by_provider_id("nonexistent")
        assert result is None

-    async def test_finds_by_paddle_subscription_id(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_findme")
+    async def test_finds_by_provider_subscription_id(self, db, test_user, create_subscription):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_findme")
        result = await get_subscription_by_provider_id("sub_findme")
        assert result is not None
        assert result["user_id"] == test_user["id"]
@@ -117,18 +153,14 @@ class TestGetSubscriptionByProviderId:

class TestUpdateSubscriptionStatus:
    async def test_updates_status(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_upd")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_upd")
        await update_subscription_status("sub_upd", status="cancelled")
        row = await get_subscription(test_user["id"])
        assert row["status"] == "cancelled"
        assert row["updated_at"] is not None

    async def test_updates_extra_fields(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_extra")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_extra")
        await update_subscription_status(
            "sub_extra",
            status="active",
@@ -141,9 +173,7 @@ class TestUpdateSubscriptionStatus:
        assert row["current_period_end"] == "2026-01-01T00:00:00Z"

    async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_known", status="active")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_known", status="active")
        await update_subscription_status("sub_unknown", status="expired")
        row = await get_subscription(test_user["id"])
        assert row["status"] == "active"  # unchanged
@@ -155,22 +185,22 @@ class TestUpdateSubscriptionStatus:

class TestCanAccessFeature:
    async def test_no_subscription_gets_free_features(self, db, test_user):
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
        assert await can_access_feature(test_user["id"], "export") is False
        assert await can_access_feature(test_user["id"], "api") is False

    async def test_active_pro_gets_all_features(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="active")
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
        assert await can_access_feature(test_user["id"], "export") is True
        assert await can_access_feature(test_user["id"], "api") is True
        assert await can_access_feature(test_user["id"], "priority_support") is True

    async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="starter", status="active")
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
        assert await can_access_feature(test_user["id"], "export") is True
-        assert await can_access_feature(test_user["id"], "all_commodities") is False
+        assert await can_access_feature(test_user["id"], "api") is False

    async def test_cancelled_still_has_features(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="cancelled")
@@ -183,7 +213,7 @@ class TestCanAccessFeature:
    async def test_expired_falls_back_to_free(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="expired")
        assert await can_access_feature(test_user["id"], "api") is False
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True

    async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="past_due")
@@ -203,30 +233,28 @@ class TestCanAccessFeature:
# ════════════════════════════════════════════════════════════

class TestIsWithinLimits:
-    async def test_free_user_no_api_calls(self, db, test_user):
-        assert await is_within_limits(test_user["id"], "api_calls", 0) is False
+    async def test_free_user_within_limits(self, db, test_user):
+        assert await is_within_limits(test_user["id"], "items", 50) is True

-    async def test_free_user_commodity_limit(self, db, test_user):
-        assert await is_within_limits(test_user["id"], "commodities", 0) is True
-        assert await is_within_limits(test_user["id"], "commodities", 1) is False
+    async def test_free_user_at_limit(self, db, test_user):
+        assert await is_within_limits(test_user["id"], "items", 100) is False

-    async def test_free_user_history_limit(self, db, test_user):
-        assert await is_within_limits(test_user["id"], "history_years", 4) is True
-        assert await is_within_limits(test_user["id"], "history_years", 5) is False
+    async def test_free_user_over_limit(self, db, test_user):
+        assert await is_within_limits(test_user["id"], "items", 150) is False

    async def test_pro_unlimited(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="active")
-        assert await is_within_limits(test_user["id"], "commodities", 999999) is True
+        assert await is_within_limits(test_user["id"], "items", 999999) is True
        assert await is_within_limits(test_user["id"], "api_calls", 999999) is True

    async def test_starter_limits(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="starter", status="active")
        assert await is_within_limits(test_user["id"], "api_calls", 9999) is True
        assert await is_within_limits(test_user["id"], "api_calls", 10000) is False
        assert await is_within_limits(test_user["id"], "items", 999) is True
        assert await is_within_limits(test_user["id"], "items", 1000) is False

    async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription):
        await create_subscription(test_user["id"], plan="pro", status="expired")
-        assert await is_within_limits(test_user["id"], "api_calls", 0) is False
+        assert await is_within_limits(test_user["id"], "items", 100) is False

    async def test_unknown_resource_returns_false(self, db, test_user):
        assert await is_within_limits(test_user["id"], "unicorns", 0) is False
@@ -238,7 +266,7 @@ class TestIsWithinLimits:
# ════════════════════════════════════════════════════════════

STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"]
-FEATURES = ["dashboard", "export", "api", "priority_support"]
+FEATURES = ["basic", "export", "api", "priority_support"]
ACTIVE_STATUSES = {"active", "on_trial", "cancelled"}
@@ -282,9 +310,9 @@ async def test_plan_feature_matrix(db, test_user, create_subscription, plan, fea

@pytest.mark.parametrize("plan", PLANS)
@pytest.mark.parametrize("resource,at_limit", [
-    ("commodities", 1),
+    ("commodities", 65),
    ("api_calls", 0),
    ("items", 100),
    ("items", 1000),
    ("api_calls", 1000),
    ("api_calls", 10000),
])
async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit):
@@ -307,11 +335,11 @@ async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resou
# ════════════════════════════════════════════════════════════

class TestLimitsHypothesis:
-    @given(count=st.integers(min_value=0, max_value=100))
+    @given(count=st.integers(min_value=0, max_value=10000))
    @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
-    async def test_free_limit_boundary_commodities(self, db, test_user, count):
-        result = await is_within_limits(test_user["id"], "commodities", count)
-        assert result == (count < 1)
+    async def test_free_limit_boundary_items(self, db, test_user, count):
+        result = await is_within_limits(test_user["id"], "items", count)
+        assert result == (count < 100)

    @given(count=st.integers(min_value=0, max_value=100000))
    @h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
@@ -319,7 +347,56 @@ class TestLimitsHypothesis:
        # Use upsert to avoid duplicate inserts across Hypothesis examples
        await upsert_subscription(
            user_id=test_user["id"], plan="pro", status="active",
-            provider_customer_id="cust_hyp", provider_subscription_id="sub_hyp",
+            provider_subscription_id="sub_hyp",
        )
-        result = await is_within_limits(test_user["id"], "commodities", count)
+        result = await is_within_limits(test_user["id"], "items", count)
        assert result is True


# ════════════════════════════════════════════════════════════
# record_transaction
# ════════════════════════════════════════════════════════════

class TestRecordTransaction:
    async def test_inserts_transaction(self, db, test_user):
        txn_id = await record_transaction(
            user_id=test_user["id"],
            provider_transaction_id="txn_abc123",
            type="payment",
            amount_cents=2999,
            currency="EUR",
            status="completed",
        )
        assert txn_id is not None and txn_id > 0

        from beanflows.core import fetch_one
        row = await fetch_one(
            "SELECT * FROM transactions WHERE provider_transaction_id = ?",
            ("txn_abc123",),
        )
        assert row is not None
        assert row["user_id"] == test_user["id"]
        assert row["amount_cents"] == 2999
        assert row["currency"] == "EUR"
        assert row["status"] == "completed"

    async def test_idempotent_on_duplicate_provider_id(self, db, test_user):
        await record_transaction(
            user_id=test_user["id"],
            provider_transaction_id="txn_dup",
            amount_cents=1000,
        )
        # Second insert with same provider_transaction_id should be ignored
        await record_transaction(
            user_id=test_user["id"],
            provider_transaction_id="txn_dup",
            amount_cents=9999,
        )

        from beanflows.core import fetch_all
        rows = await fetch_all(
            "SELECT * FROM transactions WHERE provider_transaction_id = ?",
            ("txn_dup",),
        )
        assert len(rows) == 1
        assert rows[0]["amount_cents"] == 1000  # original value preserved
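The idempotency these tests pin down (a duplicate `provider_transaction_id` is silently ignored, so the first recorded values win) is the behavior you get from a UNIQUE constraint plus `INSERT OR IGNORE`. A minimal synchronous sketch, assuming a simplified schema and plain `sqlite3` in place of the project's async DB helpers:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE transactions (
        id INTEGER PRIMARY KEY,
        user_id INTEGER NOT NULL,
        provider_transaction_id TEXT NOT NULL UNIQUE,
        amount_cents INTEGER NOT NULL
    )"""
)


def record_transaction(user_id, provider_transaction_id, amount_cents):
    # INSERT OR IGNORE: a row violating the UNIQUE constraint on
    # provider_transaction_id is skipped, so replayed webhooks are no-ops.
    conn.execute(
        "INSERT OR IGNORE INTO transactions"
        " (user_id, provider_transaction_id, amount_cents) VALUES (?, ?, ?)",
        (user_id, provider_transaction_id, amount_cents),
    )
    conn.commit()


record_transaction(1, "txn_dup", 1000)
record_transaction(1, "txn_dup", 9999)  # duplicate: ignored
rows = conn.execute(
    "SELECT amount_cents FROM transactions WHERE provider_transaction_id = ?",
    ("txn_dup",),
).fetchall()
# rows keeps only the original 1000-cent row
```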
122  web/tests/test_billing_hooks.py  Normal file
@@ -0,0 +1,122 @@
"""
Tests for the billing event hook system.
"""
import pytest

from beanflows.billing.routes import _billing_hooks, _fire_hooks, on_billing_event


@pytest.fixture(autouse=True)
def clear_hooks():
    """Ensure hooks are clean before and after each test."""
    _billing_hooks.clear()
    yield
    _billing_hooks.clear()


# ════════════════════════════════════════════════════════════
# Registration
# ════════════════════════════════════════════════════════════

class TestOnBillingEvent:
    def test_registers_single_event(self):
        @on_billing_event("subscription.activated")
        async def my_hook(event_type, data):
            pass

        assert "subscription.activated" in _billing_hooks
        assert my_hook in _billing_hooks["subscription.activated"]

    def test_registers_multiple_events(self):
        @on_billing_event("subscription.activated", "subscription.updated")
        async def my_hook(event_type, data):
            pass

        assert my_hook in _billing_hooks["subscription.activated"]
        assert my_hook in _billing_hooks["subscription.updated"]

    def test_multiple_hooks_per_event(self):
        @on_billing_event("subscription.activated")
        async def hook_a(event_type, data):
            pass

        @on_billing_event("subscription.activated")
        async def hook_b(event_type, data):
            pass

        assert len(_billing_hooks["subscription.activated"]) == 2

    def test_decorator_returns_original_function(self):
        @on_billing_event("test_event")
        async def my_hook(event_type, data):
            pass

        assert my_hook.__name__ == "my_hook"


# ════════════════════════════════════════════════════════════
# Firing
# ════════════════════════════════════════════════════════════

class TestFireHooks:
    async def test_fires_registered_hook(self):
        calls = []

        @on_billing_event("subscription.activated")
        async def recorder(event_type, data):
            calls.append((event_type, data))

        await _fire_hooks("subscription.activated", {"id": "sub_123"})
        assert len(calls) == 1
        assert calls[0] == ("subscription.activated", {"id": "sub_123"})

    async def test_no_hooks_registered_is_noop(self):
        # Should not raise
        await _fire_hooks("unregistered_event", {"id": "sub_123"})

    async def test_fires_all_hooks_for_event(self):
        calls = []

        @on_billing_event("subscription.activated")
        async def hook_a(event_type, data):
            calls.append("a")

        @on_billing_event("subscription.activated")
        async def hook_b(event_type, data):
            calls.append("b")

        await _fire_hooks("subscription.activated", {})
        assert calls == ["a", "b"]


# ════════════════════════════════════════════════════════════
# Error isolation
# ════════════════════════════════════════════════════════════

class TestHookErrorIsolation:
    async def test_failing_hook_does_not_block_others(self):
        calls = []

        @on_billing_event("subscription.activated")
        async def failing_hook(event_type, data):
            raise RuntimeError("boom")

        @on_billing_event("subscription.activated")
        async def good_hook(event_type, data):
            calls.append("ok")

        # Should not raise despite first hook failing
        await _fire_hooks("subscription.activated", {})
        assert calls == ["ok"]

    async def test_failing_hook_is_logged(self, caplog):
        @on_billing_event("subscription.activated")
        async def bad_hook(event_type, data):
            raise ValueError("test error")

        import logging
        with caplog.at_level(logging.ERROR):
            await _fire_hooks("subscription.activated", {})

        assert "bad_hook" in caplog.text
        assert "test error" in caplog.text
@@ -1,12 +1,15 @@
"""
Route integration tests for Paddle billing endpoints.
-External Paddle API calls mocked with respx.
+Paddle SDK calls mocked via mock_paddle_client fixture.
"""
-import json

-import httpx
+from unittest.mock import MagicMock
+
import pytest
import respx


CHECKOUT_METHOD = "POST"
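The `mock_paddle_client` fixture referenced throughout these tests is not shown in this diff. Based on how the tests use it (attribute access like `transactions.create` and call assertions), it is presumably a `MagicMock` swapped in for the real SDK client; a hypothetical conftest sketch, where the patch target `beanflows.billing.routes.paddle_client` is an assumption:

```python
from unittest.mock import MagicMock

import pytest


@pytest.fixture
def mock_paddle_client(monkeypatch):
    # MagicMock auto-creates nested attributes, so
    # client.transactions.create / client.subscriptions.get work out of the box.
    client = MagicMock()
    # Hypothetical patch target; the real fixture may patch a different name.
    monkeypatch.setattr("beanflows.billing.routes.paddle_client", client, raising=False)
    return client
```

Tests then configure return values (`client.transactions.create.return_value = ...`) and assert calls (`assert_called_once()`), exactly as in the hunks below.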
@@ -54,24 +57,16 @@ class TestCheckoutRoute:
        assert response.status_code in (302, 303, 307)

-    @respx.mock
-    async def test_creates_checkout_session(self, auth_client, db, test_user):
-        respx.post("https://api.paddle.com/transactions").mock(
-            return_value=httpx.Response(200, json={
-                "data": {
-                    "checkout": {
-                        "url": "https://checkout.paddle.com/test_123"
-                    }
-                }
-            })
-        )
+    async def test_creates_checkout_session(self, auth_client, db, test_user, mock_paddle_client):
+        mock_txn = MagicMock()
+        mock_txn.checkout.url = "https://checkout.paddle.com/test_123"
+        mock_paddle_client.transactions.create.return_value = mock_txn

        response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False)

        assert response.status_code in (302, 303, 307)
+        mock_paddle_client.transactions.create.assert_called_once()

    async def test_invalid_plan_rejected(self, auth_client, db, test_user):
@@ -82,20 +77,13 @@ class TestCheckoutRoute:
-    @respx.mock
-    async def test_api_error_propagates(self, auth_client, db, test_user):
-        respx.post("https://api.paddle.com/transactions").mock(
-            return_value=httpx.Response(500, json={"error": "server error"})
-        )
-        with pytest.raises(httpx.HTTPStatusError):
+    async def test_api_error_propagates(self, auth_client, db, test_user, mock_paddle_client):
+        mock_paddle_client.transactions.create.side_effect = Exception("API error")
+        with pytest.raises(Exception, match="API error"):
            await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}")


# ════════════════════════════════════════════════════════════
# Manage subscription / Portal
# ════════════════════════════════════════════════════════════
@@ -110,24 +98,18 @@ class TestManageRoute:
        response = await auth_client.post("/billing/manage", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

-    @respx.mock
-    async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
-
-        respx.get("https://api.paddle.com/subscriptions/sub_test").mock(
-            return_value=httpx.Response(200, json={
-                "data": {
-                    "management_urls": {
-                        "update_payment_method": "https://paddle.com/manage/test_123"
-                    }
-                }
-            })
-        )
+    async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test")
+
+        mock_sub = MagicMock()
+        mock_sub.management_urls.update_payment_method = "https://paddle.com/manage/test_123"
+        mock_paddle_client.subscriptions.get.return_value = mock_sub

        response = await auth_client.post("/billing/manage", follow_redirects=False)
        assert response.status_code in (302, 303, 307)
+        mock_paddle_client.subscriptions.get.assert_called_once_with("sub_test")
@@ -145,18 +127,14 @@ class TestCancelRoute:
        response = await auth_client.post("/billing/cancel", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

-    @respx.mock
-    async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
-
-        respx.post("https://api.paddle.com/subscriptions/sub_test/cancel").mock(
-            return_value=httpx.Response(200, json={"data": {}})
-        )
+    async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test")

        response = await auth_client.post("/billing/cancel", follow_redirects=False)
        assert response.status_code in (302, 303, 307)
+        mock_paddle_client.subscriptions.cancel.assert_called_once()
@@ -167,8 +145,9 @@ class TestCancelRoute:
# subscription_required decorator
# ════════════════════════════════════════════════════════════

-from beanflows.billing.routes import subscription_required
-from quart import Blueprint
+from quart import Blueprint  # noqa: E402
+
+from beanflows.auth.routes import subscription_required  # noqa: E402

test_bp = Blueprint("test", __name__)
@@ -5,13 +5,13 @@ Covers signature verification, event parsing, subscription lifecycle transitions
import json

import pytest
-from conftest import make_webhook_payload, sign_payload
from hypothesis import HealthCheck, given
from hypothesis import settings as h_settings
from hypothesis import strategies as st

-from beanflows.billing.routes import get_subscription
+from conftest import make_webhook_payload, sign_payload
+from beanflows.billing.routes import get_billing_customer, get_subscription


WEBHOOK_PATH = "/billing/webhook/paddle"
@@ -72,18 +72,19 @@ class TestWebhookSignature:
    async def test_modified_payload_rejected(self, client, db, test_user):
+        # Paddle SDK Verifier handles tamper detection internally.
+        # We test signature rejection via test_invalid_signature_rejected above.
+        # This test verifies the Verifier is actually called by sending
+        # a payload with an explicitly bad signature.
        payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"]))
        payload_bytes = json.dumps(payload).encode()
-        sig = sign_payload(payload_bytes)
-        tampered = payload_bytes + b"extra"
-
-        # Paddle/LemonSqueezy: HMAC signature verification fails before JSON parsing
        response = await client.post(
            WEBHOOK_PATH,
-            data=tampered,
-            headers={SIG_HEADER: sig, "Content-Type": "application/json"},
+            data=payload_bytes,
+            headers={SIG_HEADER: "invalid_signature", "Content-Type": "application/json"},
        )
-        assert response.status_code in (400, 401)
+        assert response.status_code == 400

    async def test_empty_payload_rejected(self, client, db):
@@ -105,7 +106,7 @@ class TestWebhookSignature:
class TestWebhookSubscriptionActivated:
-    async def test_creates_subscription(self, client, db, test_user):
+    async def test_creates_subscription_and_billing_customer(self, client, db, test_user):
        payload = make_webhook_payload(
            "subscription.activated",
            user_id=str(test_user["id"]),
@@ -126,10 +127,14 @@
        assert sub["plan"] == "starter"
        assert sub["status"] == "active"

+        bc = await get_billing_customer(test_user["id"])
+        assert bc is not None
+        assert bc["provider_customer_id"] == "ctm_test123"


class TestWebhookSubscriptionUpdated:
    async def test_updates_subscription_status(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")

        payload = make_webhook_payload(
            "subscription.updated",
@@ -152,7 +157,7 @@

class TestWebhookSubscriptionCanceled:
    async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")

        payload = make_webhook_payload(
            "subscription.canceled",
@@ -174,7 +179,7 @@

class TestWebhookSubscriptionPastDue:
    async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")

        payload = make_webhook_payload(
            "subscription.past_due",
@@ -209,7 +214,7 @@
])
async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status):
    if event_type != "subscription.activated":
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test456")

    payload = make_webhook_payload(event_type, user_id=str(test_user["id"]))
    payload_bytes = json.dumps(payload).encode()
242  web/tests/test_roles.py  Normal file
@@ -0,0 +1,242 @@
"""
Tests for role-based access control: role_required decorator, grant/revoke/ensure_admin_role,
and admin route protection.
"""
import pytest
from quart import Blueprint

from beanflows import core
from beanflows.auth.routes import (
    ensure_admin_role,
    grant_role,
    revoke_role,
    role_required,
)


# ════════════════════════════════════════════════════════════
# grant_role / revoke_role
# ════════════════════════════════════════════════════════════

class TestGrantRole:
    async def test_grants_role(self, db, test_user):
        await grant_role(test_user["id"], "admin")
        row = await core.fetch_one(
            "SELECT role FROM user_roles WHERE user_id = ?",
            (test_user["id"],),
        )
        assert row is not None
        assert row["role"] == "admin"

    async def test_idempotent(self, db, test_user):
        await grant_role(test_user["id"], "admin")
        await grant_role(test_user["id"], "admin")
        rows = await core.fetch_all(
            "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
            (test_user["id"],),
        )
        assert len(rows) == 1


class TestRevokeRole:
    async def test_revokes_existing_role(self, db, test_user):
        await grant_role(test_user["id"], "admin")
        await revoke_role(test_user["id"], "admin")
        row = await core.fetch_one(
            "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
            (test_user["id"],),
        )
        assert row is None

    async def test_noop_for_missing_role(self, db, test_user):
        # Should not raise
        await revoke_role(test_user["id"], "nonexistent")


# ════════════════════════════════════════════════════════════
# ensure_admin_role
# ════════════════════════════════════════════════════════════

class TestEnsureAdminRole:
    async def test_grants_admin_for_listed_email(self, db, test_user):
        core.config.ADMIN_EMAILS = ["test@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "test@example.com")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is not None
        finally:
            core.config.ADMIN_EMAILS = []

    async def test_skips_for_unlisted_email(self, db, test_user):
        core.config.ADMIN_EMAILS = ["boss@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "test@example.com")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is None
        finally:
            core.config.ADMIN_EMAILS = []

    async def test_empty_admin_emails_grants_nothing(self, db, test_user):
        core.config.ADMIN_EMAILS = []
        await ensure_admin_role(test_user["id"], "test@example.com")
        row = await core.fetch_one(
            "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
            (test_user["id"],),
        )
        assert row is None

    async def test_case_insensitive_matching(self, db, test_user):
        core.config.ADMIN_EMAILS = ["test@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "Test@Example.COM")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is not None
        finally:
            core.config.ADMIN_EMAILS = []


# ════════════════════════════════════════════════════════════
# role_required decorator
# ════════════════════════════════════════════════════════════

role_test_bp = Blueprint("role_test", __name__)


@role_test_bp.route("/admin-only")
@role_required("admin")
async def admin_only_route():
    return "admin-ok", 200


@role_test_bp.route("/multi-role")
@role_required("admin", "editor")
async def multi_role_route():
    return "multi-ok", 200


class TestRoleRequired:
    @pytest.fixture
    async def role_app(self, app):
        app.register_blueprint(role_test_bp)
        return app

    @pytest.fixture
    async def role_client(self, role_app):
        async with role_app.test_client() as c:
            yield c

    async def test_redirects_unauthenticated(self, role_client, db):
        response = await role_client.get("/admin-only", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_rejects_user_without_role(self, role_client, db, test_user):
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/admin-only", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_allows_user_with_matching_role(self, role_client, db, test_user):
        await grant_role(test_user["id"], "admin")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/admin-only")
        assert response.status_code == 200

    async def test_multi_role_allows_any_match(self, role_client, db, test_user):
        await grant_role(test_user["id"], "editor")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/multi-role")
        assert response.status_code == 200

    async def test_multi_role_rejects_none(self, role_client, db, test_user):
        await grant_role(test_user["id"], "viewer")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/multi-role", follow_redirects=False)
        assert response.status_code in (302, 303, 307)


# ════════════════════════════════════════════════════════════
# Admin route protection
# ════════════════════════════════════════════════════════════

class TestAdminRouteProtection:
    async def test_admin_index_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_admin_index_accessible_with_admin_role(self, admin_client, db):
        response = await admin_client.get("/admin/")
        assert response.status_code == 200

    async def test_admin_users_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/users", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_admin_tasks_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/tasks", follow_redirects=False)
        assert response.status_code in (302, 303, 307)


# ════════════════════════════════════════════════════════════
# Impersonation
# ════════════════════════════════════════════════════════════

class TestImpersonation:
    async def test_impersonate_stores_admin_id(self, admin_client, db, test_user):
        """Impersonating stores admin's user_id in session['admin_impersonating']."""
        # Create a second user to impersonate
        now = "2025-01-01T00:00:00"
        other_id = await core.execute(
            "INSERT INTO users (email, name, created_at) VALUES (?, ?, ?)",
            ("other@example.com", "Other", now),
        )

        async with admin_client.session_transaction() as sess:
            sess["csrf_token"] = "test_csrf"

        response = await admin_client.post(
            f"/admin/users/{other_id}/impersonate",
            form={"csrf_token": "test_csrf"},
            follow_redirects=False,
        )
        assert response.status_code in (302, 303, 307)

        async with admin_client.session_transaction() as sess:
            assert sess["user_id"] == other_id
            assert sess["admin_impersonating"] == test_user["id"]

    async def test_stop_impersonating_restores_admin(self, app, db, test_user, grant_role):
        """Stopping impersonation restores the admin's user_id."""
        await grant_role(test_user["id"], "admin")

        async with app.test_client() as c:
            async with c.session_transaction() as sess:
                sess["user_id"] = 999  # impersonated user
                sess["admin_impersonating"] = test_user["id"]
                sess["csrf_token"] = "test_csrf"

            response = await c.post(
                "/admin/stop-impersonating",
                form={"csrf_token": "test_csrf"},
                follow_redirects=False,
            )
            assert response.status_code in (302, 303, 307)

            async with c.session_transaction() as sess:
                assert sess["user_id"] == test_user["id"]
                assert "admin_impersonating" not in sess
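The contract `role_required(*roles)` must satisfy follows directly from these tests: redirect when unauthenticated, redirect when the session user holds none of the listed roles, and allow the view when any role matches. A framework-free sketch of that check, with the Quart session and the `user_roles` table stubbed as in-memory stand-ins (both are assumptions for illustration; the real decorator queries the database):

```python
from functools import wraps

# Hypothetical stand-ins for Quart's session and the user_roles table.
session: dict = {}
user_roles: set[tuple[int, str]] = set()  # {(user_id, role), ...}

REDIRECT = ("redirect", 302)


def role_required(*roles):
    """Allow the wrapped view only if the session user holds any of `roles`."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            user_id = session.get("user_id")
            if user_id is None:
                return REDIRECT  # unauthenticated
            if not any((user_id, role) in user_roles for role in roles):
                return REDIRECT  # authenticated but lacking every required role
            return func(*args, **kwargs)
        return wrapper
    return decorator


@role_required("admin", "editor")
def multi_role_view():
    return ("multi-ok", 200)
```

Note the any-of semantics: `@role_required("admin", "editor")` admits a user with either role, matching `test_multi_role_allows_any_match` above.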