Update from Copier template v0.4.0
- Accept RBAC system: user_roles table, role_required decorator, grant_role/revoke_role/ensure_admin_role functions
- Accept improved billing architecture: billing_customers table separation, provider-agnostic naming
- Accept enhanced user loading with subscription/roles eager loading in app.py
- Accept improved email templates with branded styling
- Accept new infrastructure: migration tracking, transaction logging, A/B testing
- Accept template improvements: Resend SDK, Tailwind build stage, UMAMI analytics config
- Keep beanflows-specific configs: BASE_URL 5001, coffee PLAN_FEATURES/PLAN_LIMITS
- Keep beanflows analytics integration and DuckDB health check
- Add new test files and utility scripts from template

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
73
web/CHANGELOG.md
Normal file
@@ -0,0 +1,73 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## [Unreleased]

### Changed

- **Role-based access control**: `user_roles` table with `role_required()` decorator replaces password-based admin auth
- **Admin is a real user**: admins authenticate via magic links; the `ADMIN_EMAILS` env var auto-grants the admin role on login
- **Separated billing entities**: `billing_customers` table holds payment provider identity; the `subscriptions` table holds only subscription state
- **Multiple subscriptions per user**: dropped the UNIQUE constraint on `subscriptions.user_id`; `upsert_subscription` finds rows by `provider_subscription_id`
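The check behind the new `role_required()` decorator boils down to role overlap. A minimal pure-function sketch of that check (decorator plumbing and redirects omitted; the name `has_required_role` is illustrative, not from the template):

```python
def has_required_role(user_roles: list[str], required: tuple[str, ...]) -> bool:
    # Access is granted when the user holds at least one of the required roles.
    return any(r in user_roles for r in required)
```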
### Added

- Simple A/B testing with `@ab_test` decorator and optional Umami `data-tag` integration (`UMAMI_SCRIPT_URL` / `UMAMI_WEBSITE_ID` env vars)
- `user_roles` table and `grant_role()` / `revoke_role()` / `ensure_admin_role()` functions
- `billing_customers` table and `upsert_billing_customer()` / `get_billing_customer()` functions
- `role_required(*roles)` decorator in auth
- `is_admin` template context variable
- Migration `0001_roles_and_billing_customers.py` for existing databases
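The `@ab_test` decorator's implementation is not shown in this diff; a common approach for such decorators, sketched here purely as an assumption, is deterministic hash bucketing so the same visitor always sees the same variant:

```python
import hashlib

def assign_variant(test_name: str, user_key: str,
                   variants: tuple[str, ...] = ("a", "b")) -> str:
    # Hash test name + user key so assignment is stable across requests.
    digest = hashlib.sha256(f"{test_name}:{user_key}".encode()).digest()
    return variants[digest[0] % len(variants)]
```

A stable assignment like this is what makes the Umami `data-tag` integration useful: the tag reported for a visitor never flips between page loads.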
### Removed

- `ADMIN_PASSWORD` env var and password-based admin login
- `provider_customer_id` column from `subscriptions` table
- `admin/templates/admin/login.html`
### Changed

- **Provider-agnostic schema**: generic `provider_customer_id` / `provider_subscription_id` columns replace provider-prefixed names (`stripe_customer_id`, `paddle_customer_id`, `lemonsqueezy_customer_id`), eliminating all Jinja conditionals from the schema, SQL helpers, and route code
- **Consolidated `subscription_required` decorator**: a single implementation in `auth/routes.py` supports both plan and status checks and reads from the eager-loaded `g.subscription` (zero extra queries)
- **Eager-loaded `g.subscription`**: `load_user` in `app.py` now fetches the user and subscription in a single JOIN; available in all routes and templates via `g.subscription`
### Added

- `transactions` table for recording payment/refund events with idempotent `record_transaction()` helper
- Billing event hook system: `on_billing_event()` decorator and `_fire_hooks()` for domain code to react to subscription changes; errors are logged and never cause webhook 500s
### Removed

- Duplicate `subscription_required` decorator from `billing/routes.py` (consolidated in `auth/routes.py`)
- `get_user_with_subscription()` from `auth/routes.py` (replaced by eager-loaded `g.subscription`)
### Changed

- **Email SDK migration**: replaced raw httpx calls with the official `resend` SDK in `core.py`
  - Added `from_addr` parameter to `send_email()` for multi-address support
  - Added `EMAIL_ADDRESSES` dict for named sender addresses (transactional, etc.)
- **Paddle SDK migration**: replaced raw httpx calls with the official `paddle-python-sdk` in `billing/routes.py`
  - Checkout, manage, and cancel routes now use typed SDK methods (`PaddleClient`, `CreateTransaction`)
  - Webhook verification uses the SDK's `Verifier` instead of hand-rolled HMAC
  - Added `PADDLE_ENVIRONMENT` config for sandbox/production toggling
  - Added `_paddle_client()` helper factory
- **Dependencies**: `resend` replaces `httpx` for email; `paddle-python-sdk` replaces `httpx` for Paddle billing; `httpx` is now only included for LemonSqueezy projects
- Worker `send_email` task handler now passes through `from_addr`
### Added

- `scripts/setup_paddle.py`: CLI script to create Paddle products/prices programmatically (Paddle projects only)
### Changed

- **Pico CSS → Tailwind CSS v4**: full design system migration across all templates
  - Standalone Tailwind CLI binary (no Node.js) with `make css-build` / `make css-watch`
  - Brand theme with component classes (`.btn`, `.card`, `.form-input`, `.table`, `.badge`, `.flash`, etc.)
  - Self-hosted Commit Mono font for monospace data display
  - Docker multi-stage build: CSS compiled in a dedicated stage before the Python build

### Removed

- Pico CSS CDN dependency
- `custom.css` (replaced by Tailwind `input.css` with `@layer components`)
- JetBrains Mono font (replaced by self-hosted Commit Mono)
### Fixed

- Admin template collision: namespaced admin templates under an `admin/` subdirectory to prevent Quart's template loader from resolving auth's `login.html` or dashboard's `index.html` instead of admin's
- Admin user detail: `stripe_customer_id` was hardcoded regardless of payment provider; now uses a provider-aware Copier conditional (Stripe/Paddle/LemonSqueezy)
### Added

- Initial project scaffolded from quart_saas_boilerplate
@@ -1,3 +1,13 @@
+# CSS build stage (Tailwind standalone CLI, no Node.js)
+FROM debian:bookworm-slim AS css-build
+ADD https://github.com/tailwindlabs/tailwindcss/releases/latest/download/tailwindcss-linux-x64 /usr/local/bin/tailwindcss
+RUN chmod +x /usr/local/bin/tailwindcss
+WORKDIR /app
+COPY src/ ./src/
+RUN tailwindcss -i ./src/beanflows/static/css/input.css \
+    -o ./src/beanflows/static/css/output.css --minify
+
# Build stage
FROM python:3.12-slim AS build
COPY --from=ghcr.io/astral-sh/uv:0.8 /uv /uvx /bin/
@@ -15,6 +25,7 @@ RUN useradd -m -u 1000 appuser
WORKDIR /app
RUN mkdir -p /app/data && chown -R appuser:appuser /app
COPY --from=build --chown=appuser:appuser /app .
+COPY --from=css-build /app/src/beanflows/static/css/output.css ./src/beanflows/static/css/output.css
USER appuser
ENV PYTHONUNBUFFERED=1
ENV DATABASE_PATH=/app/data/app.db
@@ -12,8 +12,9 @@ dependencies = [
    "aiosqlite>=0.19.0",
    "duckdb>=1.0.0",
    "httpx>=0.27.0",
+    "resend>=2.22.0",
    "python-dotenv>=1.0.0",
+    "paddle-python-sdk>=1.13.0",
    "itsdangerous>=2.1.0",
    "jinja2>=3.1.0",
    "hypercorn>=0.17.0",
@@ -12,24 +12,24 @@ from .core import close_db, config, get_csrf_token, init_db, setup_request_id
def create_app() -> Quart:
    """Create and configure the Quart application."""

    # Get package directory for templates
    pkg_dir = Path(__file__).parent

    app = Quart(
        __name__,
        template_folder=str(pkg_dir / "templates"),
        static_folder=str(pkg_dir / "static"),
    )

    app.secret_key = config.SECRET_KEY

    # Session config
    app.config["SESSION_COOKIE_SECURE"] = not config.DEBUG
    app.config["SESSION_COOKIE_HTTPONLY"] = True
    app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
    app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * config.SESSION_LIFETIME_DAYS

    # Database lifecycle
    @app.before_serving
    async def startup():
@@ -41,7 +41,7 @@ def create_app() -> Quart:
    async def shutdown():
        close_analytics_db()
        await close_db()

    # Security headers
    @app.after_request
    async def add_security_headers(response):
@@ -51,16 +51,42 @@ def create_app() -> Quart:
        if not config.DEBUG:
            response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
        return response

-    # Load current user before each request
+    # Load current user + subscription + roles before each request
    @app.before_request
    async def load_user():
        g.user = None
+        g.subscription = None
        user_id = session.get("user_id")
        if user_id:
-            from .auth.routes import get_user_by_id
-            g.user = await get_user_by_id(user_id)
+            from .core import fetch_one as _fetch_one
+            row = await _fetch_one(
+                """SELECT u.*,
+                          bc.provider_customer_id,
+                          (SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv,
+                          s.id AS sub_id, s.plan, s.status AS sub_status,
+                          s.provider_subscription_id, s.current_period_end
+                   FROM users u
+                   LEFT JOIN billing_customers bc ON bc.user_id = u.id
+                   LEFT JOIN subscriptions s ON s.id = (
+                       SELECT id FROM subscriptions
+                       WHERE user_id = u.id
+                       ORDER BY created_at DESC LIMIT 1
+                   )
+                   WHERE u.id = ? AND u.deleted_at IS NULL""",
+                (user_id,),
+            )
+            if row:
+                g.user = dict(row)
+                g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else []
+                if row["sub_id"]:
+                    g.subscription = {
+                        "id": row["sub_id"], "plan": row["plan"],
+                        "status": row["sub_status"],
+                        "provider_subscription_id": row["provider_subscription_id"],
+                        "current_period_end": row["current_period_end"],
+                    }

    # Template context globals
    @app.context_processor
    def inject_globals():
@@ -68,10 +94,14 @@ def create_app() -> Quart:
        return {
            "config": config,
            "user": g.get("user"),
+            "subscription": g.get("subscription"),
+            "is_admin": "admin" in (g.get("user") or {}).get("roles", []),
            "now": datetime.utcnow(),
            "csrf_token": get_csrf_token,
+            "ab_variant": getattr(g, "ab_variant", None),
+            "ab_tag": getattr(g, "ab_tag", None),
        }

    # Health check
    @app.route("/health")
    async def health():
@@ -94,7 +124,7 @@ def create_app() -> Quart:
            result["duckdb"] = "not configured"
        status_code = 200 if result["status"] == "healthy" else 500
        return result, status_code

    # Register blueprints
    from .admin.routes import bp as admin_bp
    from .api.routes import bp as api_bp
@@ -102,17 +132,17 @@ def create_app() -> Quart:
    from .billing.routes import bp as billing_bp
    from .dashboard.routes import bp as dashboard_bp
    from .public.routes import bp as public_bp

    app.register_blueprint(public_bp)
    app.register_blueprint(auth_bp)
    app.register_blueprint(dashboard_bp)
    app.register_blueprint(billing_bp)
    app.register_blueprint(api_bp, url_prefix="/api/v1")
    app.register_blueprint(admin_bp)

    # Request ID tracking
    setup_request_id(app)

    return app
@@ -71,7 +71,7 @@ async def get_valid_token(token: str) -> dict | None:
    """Get token if valid and not expired."""
    return await fetch_one(
        """
        SELECT at.*, u.email
        FROM auth_tokens at
        JOIN users u ON u.id = at.user_id
        WHERE at.token = ? AND at.expires_at > ? AND at.used_at IS NULL
@@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None:
    )


-async def get_user_with_subscription(user_id: int) -> dict | None:
-    """Get user with their active subscription info."""
-    return await fetch_one(
-        """
-        SELECT u.*, s.plan, s.status as sub_status, s.current_period_end
-        FROM users u
-        LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active'
-        WHERE u.id = ? AND u.deleted_at IS NULL
-        """,
-        (user_id,)
-    )


# =============================================================================
# Decorators
# =============================================================================
@@ -116,24 +103,69 @@ def login_required(f):
    return decorated


-def subscription_required(plans: list[str] = None):
-    """Require active subscription, optionally of specific plan(s)."""
+def role_required(*roles):
+    """Require user to have at least one of the given roles."""
+    def decorator(f):
+        @wraps(f)
+        async def decorated(*args, **kwargs):
+            if not g.get("user"):
+                await flash("Please sign in to continue.", "warning")
+                return redirect(url_for("auth.login", next=request.path))
+            user_roles = g.user.get("roles", [])
+            if not any(r in user_roles for r in roles):
+                await flash("You don't have permission to access that page.", "error")
+                return redirect(url_for("dashboard.index"))
+            return await f(*args, **kwargs)
+        return decorated
+    return decorator
+
+
+async def grant_role(user_id: int, role: str) -> None:
+    """Grant a role to a user (idempotent)."""
+    await execute(
+        "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
+        (user_id, role),
+    )
+
+
+async def revoke_role(user_id: int, role: str) -> None:
+    """Revoke a role from a user."""
+    await execute(
+        "DELETE FROM user_roles WHERE user_id = ? AND role = ?",
+        (user_id, role),
+    )
+
+
+async def ensure_admin_role(user_id: int, email: str) -> None:
+    """Grant admin role if email is in ADMIN_EMAILS."""
+    if email.lower() in config.ADMIN_EMAILS:
+        await grant_role(user_id, "admin")
+
+
+def subscription_required(
+    plans: list[str] = None,
+    allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"),
+):
+    """Require active subscription, optionally of specific plan(s) and/or statuses.
+
+    Reads from g.subscription (eager-loaded in load_user) — zero extra queries.
+    """
    def decorator(f):
        @wraps(f)
        async def decorated(*args, **kwargs):
            if not g.get("user"):
                await flash("Please sign in to continue.", "warning")
                return redirect(url_for("auth.login"))

-            user = await get_user_with_subscription(g.user["id"])
-            if not user or not user.get("plan"):
+            sub = g.get("subscription")
+            if not sub or sub["status"] not in allowed:
                await flash("Please subscribe to access this feature.", "warning")
                return redirect(url_for("billing.pricing"))

-            if plans and user["plan"] not in plans:
+            if plans and sub["plan"] not in plans:
                await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning")
                return redirect(url_for("billing.pricing"))

            return await f(*args, **kwargs)
        return decorated
    return decorator
@@ -149,33 +181,33 @@ async def login():
    """Login page - request magic link."""
    if g.get("user"):
        return redirect(url_for("dashboard.index"))

    if request.method == "POST":
        form = await request.form
        email = form.get("email", "").strip().lower()

        if not email or "@" not in email:
            await flash("Please enter a valid email address.", "error")
            return redirect(url_for("auth.login"))

        # Get or create user
        user = await get_user_by_email(email)
        if not user:
            user_id = await create_user(email)
        else:
            user_id = user["id"]

        # Create magic link token
        token = secrets.token_urlsafe(32)
        await create_auth_token(user_id, token)

        # Queue email
        from ..worker import enqueue
        await enqueue("send_magic_link", {"email": email, "token": token})

        await flash("Check your email for the sign-in link!", "success")
        return redirect(url_for("auth.magic_link_sent", email=email))

    return await render_template("login.html")
@@ -185,39 +217,39 @@ async def signup():
    """Signup page - same as login but with different messaging."""
    if g.get("user"):
        return redirect(url_for("dashboard.index"))

    plan = request.args.get("plan", "free")

    if request.method == "POST":
        form = await request.form
        email = form.get("email", "").strip().lower()
        selected_plan = form.get("plan", "free")

        if not email or "@" not in email:
            await flash("Please enter a valid email address.", "error")
            return redirect(url_for("auth.signup", plan=selected_plan))

        # Check if user exists
        user = await get_user_by_email(email)
        if user:
            await flash("Account already exists. Please sign in.", "info")
            return redirect(url_for("auth.login"))

        # Create user
        user_id = await create_user(email)

        # Create magic link token
        token = secrets.token_urlsafe(32)
        await create_auth_token(user_id, token)

        # Queue emails
        from ..worker import enqueue
        await enqueue("send_magic_link", {"email": email, "token": token})
        await enqueue("send_welcome", {"email": email})

        await flash("Check your email to complete signup!", "success")
        return redirect(url_for("auth.magic_link_sent", email=email))

    return await render_template("signup.html", plan=plan)
@@ -225,29 +257,32 @@ async def signup():
async def verify():
    """Verify magic link token."""
    token = request.args.get("token")

    if not token:
        await flash("Invalid or expired link.", "error")
        return redirect(url_for("auth.login"))

    token_data = await get_valid_token(token)

    if not token_data:
        await flash("Invalid or expired link. Please request a new one.", "error")
        return redirect(url_for("auth.login"))

    # Mark token as used
    await mark_token_used(token_data["id"])

    # Update last login
    await update_user(token_data["user_id"], last_login_at=datetime.utcnow().isoformat())

    # Set session
    session.permanent = True
    session["user_id"] = token_data["user_id"]

+    # Auto-grant admin role if email is in ADMIN_EMAILS
+    await ensure_admin_role(token_data["user_id"], token_data["email"])
+
    await flash("Successfully signed in!", "success")

    # Redirect to intended page or dashboard
    next_url = request.args.get("next", url_for("dashboard.index"))
    return redirect(next_url)
@@ -274,18 +309,21 @@ async def dev_login():
    """Instant login for development. Only works in DEBUG mode."""
    if not config.DEBUG:
        return "Not available", 404

    email = request.args.get("email", "dev@localhost")

    user = await get_user_by_email(email)
    if not user:
        user_id = await create_user(email)
    else:
        user_id = user["id"]

    session.permanent = True
    session["user_id"] = user_id

+    # Auto-grant admin role if email is in ADMIN_EMAILS
+    await ensure_admin_role(user_id, email)
+
    await flash(f"Dev login as {email}", "success")
    return redirect(url_for("dashboard.index"))
@@ -296,19 +334,19 @@ async def resend():
    """Resend magic link."""
    form = await request.form
    email = form.get("email", "").strip().lower()

    if not email:
        await flash("Email address required.", "error")
        return redirect(url_for("auth.login"))

    user = await get_user_by_email(email)
    if user:
        token = secrets.token_urlsafe(32)
        await create_auth_token(user["id"], token)

        from ..worker import enqueue
        await enqueue("send_magic_link", {"email": email, "token": token})

    # Always show success (don't reveal if email exists)
    await flash("If that email is registered, we've sent a new link.", "success")
    return redirect(url_for("auth.magic_link_sent", email=email))
@@ -2,16 +2,19 @@
Core infrastructure: database, config, email, and shared utilities.
"""
import os
+import random
import secrets
import hashlib
import hmac

import aiosqlite
+import resend
import httpx
from pathlib import Path
from functools import wraps
from datetime import datetime, timedelta
from contextvars import ContextVar
-from quart import request, session, g
+from quart import g, make_response, request, session
from dotenv import load_dotenv

# web/.env is three levels up from web/src/beanflows/core.py
@@ -26,27 +29,35 @@ class Config:
     SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production")
     BASE_URL: str = os.getenv("BASE_URL", "http://localhost:5001")
     DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"

     DATABASE_PATH: str = os.getenv("DATABASE_PATH", "data/app.db")

     MAGIC_LINK_EXPIRY_MINUTES: int = int(os.getenv("MAGIC_LINK_EXPIRY_MINUTES", "15"))
     SESSION_LIFETIME_DAYS: int = int(os.getenv("SESSION_LIFETIME_DAYS", "30"))

     PAYMENT_PROVIDER: str = "paddle"

     PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "")
     PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "")
+    PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
     PADDLE_PRICES: dict = {
         "starter": os.getenv("PADDLE_PRICE_STARTER", ""),
         "pro": os.getenv("PADDLE_PRICE_PRO", ""),
     }

+    UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "")
+    UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "")
+
     RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "")
     EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com")

+    ADMIN_EMAILS: list[str] = [
+        e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
+    ]
+
     RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
     RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))

     PLAN_FEATURES: dict = {
         "free": ["dashboard", "coffee_only", "limited_history"],
         "starter": ["dashboard", "coffee_only", "full_history", "export", "api"],
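The `ADMIN_EMAILS` comprehension added in this hunk normalizes a comma-separated env var (trimming whitespace, lowercasing, dropping empties). A minimal standalone sketch of the same parsing, using a hypothetical env value for illustration:

```python
import os

# Hypothetical env value; the real one comes from web/.env
os.environ["ADMIN_EMAILS"] = " Alice@Example.com, bob@example.com ,, "

# Same comprehension as Config.ADMIN_EMAILS above
admin_emails = [
    e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
]
print(admin_emails)  # ['alice@example.com', 'bob@example.com']
```

Empty segments and stray spaces are discarded, so a trailing comma in the env file is harmless.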
@@ -74,10 +85,10 @@ async def init_db(path: str = None) -> None:
     global _db
     db_path = path or config.DATABASE_PATH
     Path(db_path).parent.mkdir(parents=True, exist_ok=True)

     _db = await aiosqlite.connect(db_path)
     _db.row_factory = aiosqlite.Row

     await _db.execute("PRAGMA journal_mode=WAL")
     await _db.execute("PRAGMA foreign_keys=ON")
     await _db.execute("PRAGMA busy_timeout=5000")
@@ -137,11 +148,11 @@ async def execute_many(sql: str, params_list: list[tuple]) -> None:

 class transaction:
     """Async context manager for transactions."""

     async def __aenter__(self):
         self.db = await get_db()
         return self.db

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if exc_type is None:
             await self.db.commit()
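The `transaction` context manager shown in this hunk commits on clean exit; the rollback branch falls outside the visible lines, so it is assumed here. A self-contained sketch of both paths against a stand-in connection that records calls:

```python
import asyncio


class FakeDB:
    """Stand-in for the aiosqlite connection; records commit/rollback calls."""

    def __init__(self):
        self.calls = []

    async def commit(self):
        self.calls.append("commit")

    async def rollback(self):
        self.calls.append("rollback")


class transaction:
    """Mirrors the context manager above, adapted to take the db explicitly."""

    def __init__(self, db):
        self.db = db

    async def __aenter__(self):
        return self.db

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            await self.db.commit()
        else:
            # Assumed rollback branch (cut off at the hunk boundary)
            await self.db.rollback()
        return False  # propagate exceptions to the caller


async def main():
    db = FakeDB()
    async with transaction(db):
        pass  # happy path -> commit
    try:
        async with transaction(db):
            raise ValueError("boom")  # failure path -> rollback
    except ValueError:
        pass
    return db.calls


calls = asyncio.run(main())
print(calls)  # ['commit', 'rollback']
```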
@@ -153,25 +164,32 @@ class transaction:
 # Email
 # =============================================================================

-async def send_email(to: str, subject: str, html: str, text: str = None) -> bool:
-    """Send email via Resend API."""
+EMAIL_ADDRESSES = {
+    "transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>",
+}
+
+
+async def send_email(
+    to: str, subject: str, html: str, text: str = None, from_addr: str = None
+) -> bool:
+    """Send email via Resend SDK."""
     if not config.RESEND_API_KEY:
         print(f"[EMAIL] Would send to {to}: {subject}")
         return True

-    async with httpx.AsyncClient() as client:
-        response = await client.post(
-            "https://api.resend.com/emails",
-            headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"},
-            json={
-                "from": config.EMAIL_FROM,
-                "to": to,
-                "subject": subject,
-                "html": html,
-                "text": text or html,
-            },
-        )
-    return response.status_code == 200
+    resend.api_key = config.RESEND_API_KEY
+    try:
+        resend.Emails.send({
+            "from": from_addr or config.EMAIL_FROM,
+            "to": to,
+            "subject": subject,
+            "html": html,
+            "text": text or html,
+        })
+        return True
+    except Exception as e:
+        print(f"[EMAIL] Error sending to {to}: {e}")
+        return False

 # =============================================================================
 # CSRF Protection
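The new `EMAIL_ADDRESSES` map formats a display-name sender, and the added `from_addr` parameter falls back to `config.EMAIL_FROM`. A sketch of both, with stand-in values for `config.APP_NAME` / `config.EMAIL_FROM` (the helper name `resolve_from` is hypothetical, introduced only to show the fallback):

```python
# Stand-ins for config.APP_NAME / config.EMAIL_FROM
APP_NAME = "BeanFlows"
EMAIL_FROM = "hello@example.com"

EMAIL_ADDRESSES = {
    "transactional": f"{APP_NAME} <{EMAIL_FROM}>",
}


def resolve_from(from_addr=None):
    # Same fallback as the new send_email signature: explicit from_addr wins
    return from_addr or EMAIL_FROM


print(EMAIL_ADDRESSES["transactional"])  # BeanFlows <hello@example.com>
print(resolve_from())                    # hello@example.com
print(resolve_from(EMAIL_ADDRESSES["transactional"]))
```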
@@ -214,34 +232,34 @@ async def check_rate_limit(key: str, limit: int = None, window: int = None) -> t
     window = window or config.RATE_LIMIT_WINDOW
     now = datetime.utcnow()
     window_start = now - timedelta(seconds=window)

     # Clean old entries and count recent
     await execute(
         "DELETE FROM rate_limits WHERE key = ? AND timestamp < ?",
         (key, window_start.isoformat())
     )

     result = await fetch_one(
         "SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?",
         (key, window_start.isoformat())
     )
     count = result["count"] if result else 0

     info = {
         "limit": limit,
         "remaining": max(0, limit - count - 1),
         "reset": int((window_start + timedelta(seconds=window)).timestamp()),
     }

     if count >= limit:
         return False, info

     # Record this request
     await execute(
         "INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)",
         (key, now.isoformat())
     )

     return True, info
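The hunk above implements a sliding window: purge entries older than the window, count what remains, reject at the limit, otherwise record the hit. The same logic can be sketched with an in-memory dict standing in for the `rate_limits` table (a hypothetical synchronous reimplementation, not the SQLite-backed original):

```python
from datetime import datetime, timedelta

# In-memory stand-in for the rate_limits table: key -> list of hit timestamps
_hits: dict[str, list[datetime]] = {}


def check_rate_limit(key: str, limit: int, window: int, now: datetime):
    window_start = now - timedelta(seconds=window)
    # Drop entries older than the window (the DELETE above)
    hits = [t for t in _hits.get(key, []) if t > window_start]
    info = {
        "limit": limit,
        "remaining": max(0, limit - len(hits) - 1),
        "reset": int((window_start + timedelta(seconds=window)).timestamp()),
    }
    if len(hits) >= limit:
        _hits[key] = hits
        return False, info
    hits.append(now)  # record this request (the INSERT above)
    _hits[key] = hits
    return True, info


now = datetime(2024, 1, 1, 12, 0, 0)
results = [check_rate_limit("ip:1.2.3.4", 3, 60, now)[0] for _ in range(4)]
print(results)  # [True, True, True, False]
```

The fourth call inside the same window is rejected; once `now` advances past the window, old hits age out and requests pass again.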
@@ -254,13 +272,13 @@ def rate_limit(limit: int = None, window: int = None, key_func=None):
                 key = key_func()
             else:
                 key = f"ip:{request.remote_addr}"

             allowed, info = await check_rate_limit(key, limit, window)

             if not allowed:
                 response = {"error": "Rate limit exceeded", **info}
                 return response, 429

             return await f(*args, **kwargs)
         return decorated
     return decorator
@@ -284,7 +302,7 @@ def setup_request_id(app):
         rid = request.headers.get("X-Request-ID") or secrets.token_hex(8)
         request_id_var.set(rid)
         g.request_id = rid

     @app.after_request
     async def add_request_id_header(response):
         response.headers["X-Request-ID"] = get_request_id()
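The request-ID wiring above honors an incoming `X-Request-ID` header or mints a fresh 8-byte hex id, stashing it in a `ContextVar` so any code on the same task can read it. A framework-free sketch of that pattern (the `get_request_id` helper mirrors the one assumed by the hunk):

```python
import secrets
from contextvars import ContextVar

request_id_var: ContextVar[str] = ContextVar("request_id", default="-")


def get_request_id() -> str:
    return request_id_var.get()


# Incoming header wins; otherwise mint a fresh id (token_hex(8) -> 16 hex chars)
incoming = None  # e.g. request.headers.get("X-Request-ID")
rid = incoming or secrets.token_hex(8)
request_id_var.set(rid)

print(len(get_request_id()))  # 16
```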
@@ -294,13 +312,11 @@ def setup_request_id(app):
 # Webhook Signature Verification
 # =============================================================================

-
 def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
     """Verify HMAC-SHA256 webhook signature."""
     expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
     return hmac.compare_digest(signature, expected)

-
 # =============================================================================
 # Soft Delete Helpers
 # =============================================================================
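`verify_hmac_signature` checks a hex-encoded HMAC-SHA256 over the raw request body using constant-time comparison. A self-contained usage sketch (the secret and payload values are hypothetical):

```python
import hashlib
import hmac


def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
    """Same scheme as above: hex-encoded HMAC-SHA256 over the raw body."""
    expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
    return hmac.compare_digest(signature, expected)


secret = "whsec_demo"  # hypothetical webhook secret
payload = b'{"event":"subscription.created"}'
good = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()

print(verify_hmac_signature(payload, good, secret))         # True
print(verify_hmac_signature(payload + b" ", good, secret))  # False
```

Any byte-level change to the payload (even trailing whitespace) invalidates the signature, which is why the raw body, not a re-serialized JSON, must be verified.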
@@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int:
         f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?",
         (cutoff,)
     )
+
+
+# =============================================================================
+# A/B Testing
+# =============================================================================
+
+def ab_test(experiment: str, variants: tuple = ("control", "treatment")):
+    """Assign visitor to an A/B test variant via cookie, tag Umami pageviews."""
+    def decorator(f):
+        @wraps(f)
+        async def wrapper(*args, **kwargs):
+            cookie_key = f"ab_{experiment}"
+            assigned = request.cookies.get(cookie_key)
+            if assigned not in variants:
+                assigned = random.choice(variants)
+
+            g.ab_variant = assigned
+            g.ab_tag = f"{experiment}-{assigned}"
+
+            response = await make_response(await f(*args, **kwargs))
+            response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60)
+            return response
+        return wrapper
+    return decorator
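The decorator's core is sticky assignment: a valid existing cookie is kept, anything else (missing, or a variant name from a retired experiment) triggers a fresh random pick. That kernel can be isolated as a pure function (a hypothetical extraction, not a function in core.py):

```python
import random


def assign_variant(cookie_value, variants=("control", "treatment"), rng=random):
    """Sticky assignment: keep a valid existing cookie, else pick at random."""
    if cookie_value in variants:
        return cookie_value
    return rng.choice(variants)


print(assign_variant("treatment"))  # treatment (sticky across visits)
print(assign_variant(None) in ("control", "treatment"))             # True
print(assign_variant("stale-variant") in ("control", "treatment"))  # True
```

Because the stale cookie value is not in `variants`, the visitor is silently re-bucketed rather than leaking an unknown tag into Umami.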
web/src/beanflows/migrations/migrate.py
@@ -1,51 +1,95 @@
 """
-Simple migration runner. Runs schema.sql against the database.
+Sequential migration runner.
+
+Replays all migrations in order. All databases — fresh and existing —
+go through the same path. No schema.sql fast-path.
+
+- Scans versions/ for NNNN_*.py files and runs unapplied ones in order
+- Each migration has an up(conn) function receiving an uncommitted connection
+- All pending migrations share a single transaction (batch atomicity)
 """
-import sqlite3
-from pathlib import Path
-import os
-import sys

-# Add parent to path for imports
+import importlib
+import os
+import re
+import sqlite3
+import sys
+from pathlib import Path
+
 sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))

 from dotenv import load_dotenv

 load_dotenv()

+VERSIONS_DIR = Path(__file__).parent / "versions"
+VERSION_RE = re.compile(r"^(\d{4})_.+\.py$")
+
+# Derived from the package path: …/src/<slug>/migrations/migrate.py
+_PACKAGE = Path(__file__).parent.parent.name  # e.g. "myproject"
+
+
+def _discover_versions():
+    """Return sorted list of version file stems."""
+    if not VERSIONS_DIR.is_dir():
+        return []
+    versions = []
+    for f in sorted(VERSIONS_DIR.iterdir()):
+        if VERSION_RE.match(f.name):
+            versions.append(f.stem)
+    return versions
+

-def migrate():
-    """Run migrations."""
-    # Get database path from env or default
-    db_path = os.getenv("DATABASE_PATH", "data/app.db")
-
-    # Ensure directory exists
+def migrate(db_path=None):
+    if db_path is None:
+        db_path = os.getenv("DATABASE_PATH", "data/app.db")
     Path(db_path).parent.mkdir(parents=True, exist_ok=True)

-    # Read schema
-    schema_path = Path(__file__).parent / "schema.sql"
-    schema = schema_path.read_text()
-
-    # Connect and execute
     conn = sqlite3.connect(db_path)
-
-    # Enable WAL mode
     conn.execute("PRAGMA journal_mode=WAL")
     conn.execute("PRAGMA foreign_keys=ON")

-    # Run schema
-    conn.executescript(schema)
+    # Ensure tracking table exists before anything else
+    conn.execute("""
+        CREATE TABLE IF NOT EXISTS _migrations (
+            id INTEGER PRIMARY KEY AUTOINCREMENT,
+            name TEXT UNIQUE NOT NULL,
+            applied_at TEXT NOT NULL DEFAULT (datetime('now'))
+        )
+    """)
     conn.commit()

-    print(f"✓ Migrations complete: {db_path}")
-
-    # Show tables
+    versions = _discover_versions()
+    applied = {
+        row[0]
+        for row in conn.execute("SELECT name FROM _migrations").fetchall()
+    }
+    pending = [v for v in versions if v not in applied]
+
+    if pending:
+        for name in pending:
+            print(f" Applying {name}...")
+            mod = importlib.import_module(
+                f"{_PACKAGE}.migrations.versions.{name}"
+            )
+            mod.up(conn)
+            conn.execute(
+                "INSERT INTO _migrations (name) VALUES (?)", (name,)
+            )
+        conn.commit()
+        print(f"✓ Applied {len(pending)} migration(s): {db_path}")
+    else:
+        print(f"✓ All migrations already applied: {db_path}")
+
+    # Show tables (excluding internal sqlite/fts tables)
     cursor = conn.execute(
-        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
+        "SELECT name FROM sqlite_master WHERE type='table'"
+        " AND name NOT LIKE 'sqlite_%'"
+        " ORDER BY name"
    )
     tables = [row[0] for row in cursor.fetchall()]
     print(f" Tables: {', '.join(tables)}")

     conn.close()
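Each version file the runner discovers must expose an `up(conn)` function that works against the shared, uncommitted connection; the runner itself commits the whole pending batch. A hypothetical version file (name and table contents invented for illustration) applied the way the runner would:

```python
import sqlite3


# Contents of a hypothetical versions/0001_create_users.py
def up(conn):
    """Receives the shared, uncommitted connection; the runner commits the batch."""
    conn.execute("""
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            email TEXT UNIQUE NOT NULL,
            created_at TEXT NOT NULL DEFAULT (datetime('now'))
        )
    """)


# Apply it as the runner does: call up(conn), then commit
conn = sqlite3.connect(":memory:")
up(conn)
conn.commit()

tables = [r[0] for r in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)]
print(tables)  # ['users']
```

Because `up` never commits, a raised exception anywhere in the batch leaves the database at the last fully-applied state.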
web/src/beanflows/migrations/schema.sql
@@ -1,6 +1,13 @@
 -- BeanFlows Database Schema
 -- Run with: python -m beanflows.migrations.migrate

+-- Migration tracking
+CREATE TABLE IF NOT EXISTS _migrations (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    name TEXT UNIQUE NOT NULL,
+    applied_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+
 -- Users
 CREATE TABLE IF NOT EXISTS users (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens (
 CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token);
 CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id);

+-- User Roles
+CREATE TABLE IF NOT EXISTS user_roles (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
+    role TEXT NOT NULL,
+    granted_at TEXT NOT NULL DEFAULT (datetime('now')),
+    UNIQUE(user_id, role)
+);
+CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id);
+CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role);
+
+-- Billing Customers (payment provider identity, separate from subscriptions)
+CREATE TABLE IF NOT EXISTS billing_customers (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
+    provider_customer_id TEXT NOT NULL,
+    created_at TEXT NOT NULL DEFAULT (datetime('now'))
+);
+CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id);
+CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id);
+
 -- Subscriptions
 CREATE TABLE IF NOT EXISTS subscriptions (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
-    user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
+    user_id INTEGER NOT NULL REFERENCES users(id),
     plan TEXT NOT NULL DEFAULT 'free',
     status TEXT NOT NULL DEFAULT 'free',
-    paddle_customer_id TEXT,
-    paddle_subscription_id TEXT,
+    provider_subscription_id TEXT,
     current_period_end TEXT,
     created_at TEXT NOT NULL,
     updated_at TEXT
 );

 CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
+CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id);

-CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id);
+-- Transactions
+CREATE TABLE IF NOT EXISTS transactions (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    subscription_id INTEGER REFERENCES subscriptions(id),
+    provider_transaction_id TEXT UNIQUE,
+    type TEXT NOT NULL DEFAULT 'payment',
+    amount_cents INTEGER,
+    currency TEXT DEFAULT 'USD',
+    status TEXT NOT NULL DEFAULT 'pending',
+    metadata TEXT,
+    created_at TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id);
+CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id);

 -- API Keys
 CREATE TABLE IF NOT EXISTS api_keys (
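The `UNIQUE(user_id, role)` constraint on `user_roles` is what makes role grants idempotent: `INSERT OR IGNORE` turns a repeated grant into a no-op. A synchronous sqlite3 sketch of the grant/revoke pattern (the real `grant_role` / `revoke_role` in core.py are async and reference `users(id)`; the FK is omitted here to keep the sketch self-contained):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE user_roles (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        user_id INTEGER NOT NULL,
        role TEXT NOT NULL,
        granted_at TEXT NOT NULL DEFAULT (datetime('now')),
        UNIQUE(user_id, role)
    )
""")


def grant_role(user_id: int, role: str) -> None:
    # INSERT OR IGNORE makes the grant idempotent via UNIQUE(user_id, role)
    conn.execute(
        "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
        (user_id, role),
    )


def revoke_role(user_id: int, role: str) -> None:
    conn.execute(
        "DELETE FROM user_roles WHERE user_id = ? AND role = ?",
        (user_id, role),
    )


grant_role(1, "admin")
grant_role(1, "admin")  # duplicate grant: silently ignored, not an error
count = conn.execute("SELECT COUNT(*) FROM user_roles").fetchone()[0]
print(count)  # 1
```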
@@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks (
 );

 CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at);
+
+-- Items (example domain entity - replace with your domain)
+CREATE TABLE IF NOT EXISTS items (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    user_id INTEGER NOT NULL REFERENCES users(id),
+    name TEXT NOT NULL,
+    data TEXT,
+    created_at TEXT NOT NULL,
+    updated_at TEXT,
+    deleted_at TEXT
+);
+
+CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
+CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);
+
+-- Full-text search for items (optional)
+CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
+    name,
+    data,
+    content='items',
+    content_rowid='id'
+);
+
+-- FTS triggers
+CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
+    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
+END;
+
+CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
+    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
+END;
+
+CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
+    INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
+    INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
+END;
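With an external-content FTS5 table, the triggers above keep the index in sync so writes go only to `items` and searches join back on `rowid`. A minimal end-to-end sketch using the same DDL (sample row data is invented; assumes the Python sqlite3 build ships with FTS5, which most do):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE items (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        user_id INTEGER NOT NULL,
        name TEXT NOT NULL,
        data TEXT,
        created_at TEXT NOT NULL,
        updated_at TEXT,
        deleted_at TEXT
    );
    CREATE VIRTUAL TABLE items_fts USING fts5(
        name, data, content='items', content_rowid='id'
    );
    CREATE TRIGGER items_ai AFTER INSERT ON items BEGIN
        INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
    END;
""")

# Insert goes to the content table; the trigger indexes it automatically
conn.execute(
    "INSERT INTO items (user_id, name, data, created_at)"
    " VALUES (1, 'Ethiopia Guji', 'washed, floral', datetime('now'))"
)

# Search the FTS index, join back to items on rowid
rows = conn.execute(
    "SELECT i.name FROM items_fts f JOIN items i ON i.id = f.rowid"
    " WHERE items_fts MATCH 'floral'"
).fetchall()
print(rows)  # [('Ethiopia Guji',)]
```

The `'delete'` command rows in the `items_ad` / `items_au` triggers are how external-content tables remove stale index entries, since FTS5 cannot see the content table's old values on its own.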
0
web/src/beanflows/scripts/__init__.py
Normal file
92
web/src/beanflows/scripts/setup_paddle.py
Normal file
@@ -0,0 +1,92 @@
+"""
+Create Paddle products and prices for BeanFlows.
+
+Run once per environment (sandbox, then production).
+Prints resulting price IDs as a .env snippet.
+
+Usage:
+    uv run python -m beanflows.scripts.setup_paddle
+"""
+
+import os
+import sys
+
+from dotenv import load_dotenv
+from paddle_billing import Client as PaddleClient
+from paddle_billing import Environment, Options
+from paddle_billing.Entities.Shared import CurrencyCode, Money, TaxCategory
+from paddle_billing.Resources.Prices.Operations import CreatePrice
+from paddle_billing.Resources.Products.Operations import CreateProduct
+
+load_dotenv()
+
+PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
+PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
+
+if not PADDLE_API_KEY:
+    print("ERROR: Set PADDLE_API_KEY in .env first")
+    sys.exit(1)
+
+
+PRODUCTS = [
+    # Subscriptions
+    {
+        "name": "Starter",
+        "env_key": "PADDLE_PRICE_STARTER",
+        "price": 900,
+        "currency": CurrencyCode.USD,
+        "interval": "month",
+        "type": "subscription",
+    },
+    {
+        "name": "Pro",
+        "env_key": "PADDLE_PRICE_PRO",
+        "price": 2900,
+        "currency": CurrencyCode.USD,
+        "interval": "month",
+        "type": "subscription",
+    },
+]
+
+
+def main():
+    env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
+    paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))
+
+    print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")
+
+    env_lines = []
+
+    for spec in PRODUCTS:
+        # Create product
+        product = paddle.products.create(CreateProduct(
+            name=spec["name"],
+            tax_category=TaxCategory.Standard,
+        ))
+        print(f" Product: {spec['name']} -> {product.id}")
+
+        # Create price
+        price_kwargs = {
+            "description": spec["name"],
+            "product_id": product.id,
+            "unit_price": Money(str(spec["price"]), spec["currency"]),
+        }
+
+        if spec["type"] == "subscription":
+            from paddle_billing.Entities.Shared import TimePeriod
+            price_kwargs["billing_cycle"] = TimePeriod(interval="month", frequency=1)
+
+        price = paddle.prices.create(CreatePrice(**price_kwargs))
+        print(f" Price: {spec['env_key']} = {price.id}")
+
+        env_lines.append(f"{spec['env_key']}={price.id}")
+
+    print("\n# --- .env snippet ---")
+    for line in env_lines:
+        print(line)
+
+
+if __name__ == "__main__":
+    main()
BIN
web/src/beanflows/static/fonts/CommitMono-400-Regular.woff2
Normal file
Binary file not shown.
BIN
web/src/beanflows/static/fonts/CommitMono-700-Regular.woff2
Normal file
Binary file not shown.
90
web/src/beanflows/static/fonts/CommitMono-LICENSE.txt
Normal file
@@ -0,0 +1,90 @@
+This Font Software is licensed under the SIL Open Font License, Version 1.1.
+This license is copied below, and is also available with a FAQ at:
+http://scripts.sil.org/OFL
+
+-----------------------------------------------------------
+SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
+-----------------------------------------------------------
+
+PREAMBLE
+The goals of the Open Font License (OFL) are to stimulate worldwide
+development of collaborative font projects, to support the font creation
+efforts of academic and linguistic communities, and to provide a free and
+open framework in which fonts may be shared and improved in partnership
+with others.
+
+The OFL allows the licensed fonts to be used, studied, modified and
+redistributed freely as long as they are not sold by themselves. The
+fonts, including any derivative works, can be bundled, embedded,
+redistributed and/or sold with any software provided that any reserved
+names are not used by derivative works. The fonts and derivatives,
+however, cannot be released under any other type of license. The
+requirement for fonts to remain under this license does not apply
+to any document created using the fonts or their derivatives.
+
+DEFINITIONS
+"Font Software" refers to the set of files released by the Copyright
+Holder(s) under this license and clearly marked as such. This may
+include source files, build scripts and documentation.
+
+"Reserved Font Name" refers to any names specified as such after the
+copyright statement(s).
+
+"Original Version" refers to the collection of Font Software components as
+distributed by the Copyright Holder(s).
+
+"Modified Version" refers to any derivative made by adding to, deleting,
+or substituting -- in part or in whole -- any of the components of the
+Original Version, by changing formats or by porting the Font Software to a
+new environment.
+
+"Author" refers to any designer, engineer, programmer, technical
+writer or other person who contributed to the Font Software.
+
+PERMISSION & CONDITIONS
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of the Font Software, to use, study, copy, merge, embed, modify,
+redistribute, and sell modified and unmodified copies of the Font
+Software, subject to the following conditions:
+
+1) Neither the Font Software nor any of its individual components,
+in Original or Modified Versions, may be sold by itself.
+
+2) Original or Modified Versions of the Font Software may be bundled,
+redistributed and/or sold with any software, provided that each copy
+contains the above copyright notice and this license. These can be
+included either as stand-alone text files, human-readable headers or
+in the appropriate machine-readable metadata fields within text or
+binary files as long as those fields can be easily viewed by the user.
+
+3) No Modified Version of the Font Software may use the Reserved Font
+Name(s) unless explicit written permission is granted by the corresponding
+Copyright Holder. This restriction only applies to the primary font name as
+presented to the users.
+
+4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
+Software shall not be used to promote, endorse or advertise any
+Modified Version, except to acknowledge the contribution(s) of the
+Copyright Holder(s) and the Author(s) or with their explicit written
+permission.
+
+5) The Font Software, modified or unmodified, in part or in whole,
+must be distributed entirely under this license, and must not be
+distributed under any other license. The requirement for fonts to
+remain under this license does not apply to any document created
+using the Font Software.
+
+TERMINATION
+This license becomes null and void if any of the above conditions are
+not met.
+
+DISCLAIMER
+THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
+OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
+COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
+DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
+OTHER DEALINGS IN THE FONT SOFTWARE.
@@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email
 HANDLERS: dict[str, callable] = {}


+def _email_wrap(body: str) -> str:
+    """Wrap email body in a branded layout with inline CSS."""
+    return f"""\
+<!DOCTYPE html>
+<html lang="en">
+<head><meta charset="utf-8"></head>
+<body style="margin:0;padding:0;background-color:#F8FAFC;font-family:'Inter',Helvetica,Arial,sans-serif;">
+<table width="100%" cellpadding="0" cellspacing="0" style="background-color:#F8FAFC;padding:40px 0;">
+<tr><td align="center">
+<table width="480" cellpadding="0" cellspacing="0" style="background-color:#FFFFFF;border-radius:8px;border:1px solid #E2E8F0;overflow:hidden;">
+<!-- Header -->
+<tr><td style="background-color:#0F172A;padding:24px 32px;">
+<span style="color:#FFFFFF;font-size:18px;font-weight:700;letter-spacing:-0.02em;">{config.APP_NAME}</span>
+</td></tr>
+<!-- Body -->
+<tr><td style="padding:32px;color:#475569;font-size:15px;line-height:1.6;">
+{body}
+</td></tr>
+<!-- Footer -->
+<tr><td style="padding:20px 32px;border-top:1px solid #E2E8F0;text-align:center;">
+<span style="color:#94A3B8;font-size:12px;">© {config.APP_NAME} · You received this because you have an account.</span>
+</td></tr>
+</table>
+</td></tr>
+</table>
+</body>
+</html>"""
+
+
+def _email_button(url: str, label: str) -> str:
+    """Render a branded CTA button for email."""
+    return (
+        f'<table cellpadding="0" cellspacing="0" style="margin:24px 0;">'
+        f'<tr><td style="background-color:#3B82F6;border-radius:6px;text-align:center;">'
+        f'<a href="{url}" style="display:inline-block;padding:12px 28px;'
+        f'color:#FFFFFF;font-size:15px;font-weight:600;text-decoration:none;">'
+        f'{label}</a></td></tr></table>'
+    )
+
+
 def task(name: str):
     """Decorator to register a task handler."""
     def decorator(f):
@@ -46,7 +86,7 @@ async def get_pending_tasks(limit: int = 10) -> list[dict]:
     now = datetime.utcnow().isoformat()
     return await fetch_all(
         """
         SELECT * FROM tasks
         WHERE status = 'pending' AND run_at <= ?
         ORDER BY run_at ASC
         LIMIT ?
@@ -66,15 +106,15 @@ async def mark_complete(task_id: int) -> None:
 async def mark_failed(task_id: int, error: str, retries: int) -> None:
     """Mark task as failed, schedule retry if attempts remain."""
     max_retries = 3
 
     if retries < max_retries:
         # Exponential backoff: 1min, 5min, 25min
         delay = timedelta(minutes=5 ** retries)
         run_at = datetime.utcnow() + delay
 
         await execute(
             """
             UPDATE tasks
             SET status = 'pending', error = ?, retries = ?, run_at = ?
             WHERE id = ?
             """,
@@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None:
         subject=payload["subject"],
         html=payload["html"],
         text=payload.get("text"),
+        from_addr=payload.get("from_addr"),
     )
 
 
@@ -106,35 +147,37 @@ async def handle_send_email(payload: dict) -> None:
 async def handle_send_magic_link(payload: dict) -> None:
     """Send magic link email."""
     link = f"{config.BASE_URL}/auth/verify?token={payload['token']}"
 
-    html = f"""
-    <h2>Sign in to {config.APP_NAME}</h2>
-    <p>Click the link below to sign in:</p>
-    <p><a href="{link}">{link}</a></p>
-    <p>This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>
-    <p>If you didn't request this, you can safely ignore this email.</p>
-    """
+    body = (
+        f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Sign in to {config.APP_NAME}</h2>'
+        f"<p>Click the button below to sign in. This link expires in "
+        f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>"
+        f"{_email_button(link, 'Sign In')}"
+        f'<p style="font-size:13px;color:#94A3B8;">If the button doesn\'t work, copy and paste this URL into your browser:</p>'
+        f'<p style="font-size:13px;color:#94A3B8;word-break:break-all;">{link}</p>'
+        f'<p style="font-size:13px;color:#94A3B8;">If you didn\'t request this, you can safely ignore this email.</p>'
+    )
 
     await send_email(
         to=payload["email"],
         subject=f"Sign in to {config.APP_NAME}",
-        html=html,
+        html=_email_wrap(body),
     )
 
 
 @task("send_welcome")
 async def handle_send_welcome(payload: dict) -> None:
     """Send welcome email to new user."""
-    html = f"""
-    <h2>Welcome to {config.APP_NAME}!</h2>
-    <p>Thanks for signing up. We're excited to have you.</p>
-    <p><a href="{config.BASE_URL}/dashboard">Go to your dashboard</a></p>
-    """
+    body = (
+        f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Welcome to {config.APP_NAME}!</h2>'
+        f"<p>Thanks for signing up. We're excited to have you.</p>"
+        f'{_email_button(f"{config.BASE_URL}/dashboard", "Go to Dashboard")}'
+    )
 
     await send_email(
         to=payload["email"],
         subject=f"Welcome to {config.APP_NAME}",
-        html=html,
+        html=_email_wrap(body),
    )
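Both refactored handlers above follow the same shape: build an HTML fragment (optionally with `_email_button`), then pass it through `_email_wrap` before sending. A minimal sketch of that composition, with the two helpers reduced to trivial stand-ins and `APP_NAME` hard-coded in place of `config.APP_NAME`:

```python
APP_NAME = "beanflows"  # placeholder for config.APP_NAME


def _email_wrap(body: str) -> str:
    # Stand-in for the real branded wrapper: header + body + footer.
    return f"<html><body><h1>{APP_NAME}</h1>{body}</body></html>"


def _email_button(url: str, label: str) -> str:
    # Stand-in for the real table-based CTA button.
    return f'<a href="{url}">{label}</a>'


def magic_link_body(link: str, expiry_minutes: int) -> str:
    """Compose the fragment, mirroring handle_send_magic_link's structure."""
    return (
        f"<p>Click the button below to sign in. "
        f"This link expires in {expiry_minutes} minutes.</p>"
        f"{_email_button(link, 'Sign In')}"
    )


html = _email_wrap(magic_link_body("https://example.com/auth/verify?token=t", 15))
```

The point of the split is that every email type shares one layout: handlers only produce the inner fragment.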
@@ -173,12 +216,12 @@ async def process_task(task: dict) -> None:
     task_name = task["task_name"]
     task_id = task["id"]
     retries = task.get("retries", 0)
 
     handler = HANDLERS.get(task_name)
     if not handler:
         await mark_failed(task_id, f"Unknown task: {task_name}", retries)
         return
 
     try:
         payload = json.loads(task["payload"]) if task["payload"] else {}
         await handler(payload)
@@ -194,17 +237,17 @@ async def run_worker(poll_interval: float = 1.0) -> None:
     """Main worker loop."""
     print("[WORKER] Starting...")
     await init_db()
 
     while True:
         try:
             tasks = await get_pending_tasks(limit=10)
 
             for task in tasks:
                 await process_task(task)
 
             if not tasks:
                 await asyncio.sleep(poll_interval)
 
         except Exception as e:
             print(f"[WORKER] Error: {e}")
             await asyncio.sleep(poll_interval * 5)
@@ -214,16 +257,16 @@ async def run_scheduler() -> None:
     """Schedule periodic cleanup tasks."""
     print("[SCHEDULER] Starting...")
     await init_db()
 
     while True:
         try:
             # Schedule cleanup tasks every hour
             await enqueue("cleanup_expired_tokens")
             await enqueue("cleanup_rate_limits")
             await enqueue("cleanup_old_tasks")
 
             await asyncio.sleep(3600)  # 1 hour
 
         except Exception as e:
             print(f"[SCHEDULER] Error: {e}")
             await asyncio.sleep(60)
@@ -231,8 +274,8 @@ async def run_scheduler() -> None:
 
 if __name__ == "__main__":
     import sys
 
     if len(sys.argv) > 1 and sys.argv[1] == "scheduler":
         asyncio.run(run_scheduler())
     else:
         asyncio.run(run_worker())
@@ -10,9 +10,11 @@ from unittest.mock import AsyncMock, patch
 import aiosqlite
 import pytest
 
-from beanflows import analytics, core
+from beanflows import core
 from beanflows.app import create_app
 
 
 SCHEMA_PATH = Path(__file__).parent.parent / "src" / "beanflows" / "migrations" / "schema.sql"
 
 
@@ -44,9 +46,7 @@ async def db():
 async def app(db):
     """Quart app with DB already initialized (init_db/close_db patched to no-op)."""
     with patch.object(core, "init_db", new_callable=AsyncMock), \
-         patch.object(core, "close_db", new_callable=AsyncMock), \
-         patch.object(analytics, "open_analytics_db"), \
-         patch.object(analytics, "close_analytics_db"):
+         patch.object(core, "close_db", new_callable=AsyncMock):
         application = create_app()
         application.config["TESTING"] = True
         yield application
@@ -92,22 +92,17 @@ def create_subscription(db):
         user_id: int,
         plan: str = "pro",
         status: str = "active",
-        paddle_customer_id: str = "ctm_test123",
-        paddle_subscription_id: str = "sub_test456",
+        provider_subscription_id: str = "sub_test456",
         current_period_end: str = "2025-03-01T00:00:00Z",
     ) -> int:
         now = datetime.utcnow().isoformat()
         async with db.execute(
-            """INSERT INTO subscriptions
-               (user_id, plan, status, paddle_customer_id,
-                paddle_subscription_id, current_period_end, created_at, updated_at)
-               VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
-            (user_id, plan, status, paddle_customer_id, paddle_subscription_id,
+            """INSERT INTO subscriptions
+               (user_id, plan, status,
+                provider_subscription_id, current_period_end, created_at, updated_at)
+               VALUES (?, ?, ?, ?, ?, ?, ?)""",
+            (user_id, plan, status, provider_subscription_id,
             current_period_end, now, now),
        ) as cursor:
            sub_id = cursor.lastrowid
        await db.commit()
@@ -115,6 +110,48 @@ def create_subscription(db):
     return _create
 
 
+# ── Billing Customers ───────────────────────────────────────
+
+@pytest.fixture
+def create_billing_customer(db):
+    """Factory: create a billing_customers row for a user."""
+    async def _create(user_id: int, provider_customer_id: str = "cust_test123") -> int:
+        async with db.execute(
+            """INSERT INTO billing_customers (user_id, provider_customer_id)
+               VALUES (?, ?)
+               ON CONFLICT(user_id) DO UPDATE SET provider_customer_id = excluded.provider_customer_id""",
+            (user_id, provider_customer_id),
+        ) as cursor:
+            row_id = cursor.lastrowid
+        await db.commit()
+        return row_id
+    return _create
+
+
+# ── Roles ───────────────────────────────────────────────────
+
+@pytest.fixture
+def grant_role(db):
+    """Factory: grant a role to a user."""
+    async def _grant(user_id: int, role: str) -> None:
+        await db.execute(
+            "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
+            (user_id, role),
+        )
+        await db.commit()
+    return _grant
+
+
+@pytest.fixture
+async def admin_client(app, test_user, grant_role):
+    """Test client with admin role and session['user_id'] pre-set."""
+    await grant_role(test_user["id"], "admin")
+    async with app.test_client() as c:
+        async with c.session_transaction() as sess:
+            sess["user_id"] = test_user["id"]
+        yield c
+
+
 # ── Config ───────────────────────────────────────────────────
 
 @pytest.fixture(autouse=True)
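The `admin_client` fixture above exercises the `role_required()` decorator introduced by this template update; the decorator's implementation is outside this diff. A plausible minimal sketch of its shape, using an in-memory dict in place of the `user_roles` table — the decorator name comes from the commit message, but the signature, return values, and `USER_ROLES` store are assumptions for illustration:

```python
from functools import wraps

# In-memory stand-in for the user_roles table; the real decorator queries SQLite
# for the session user's roles.
USER_ROLES: dict[int, set[str]] = {1: {"admin"}}


def role_required(role: str):
    """Assumed shape: reject callers whose user lacks the given role."""
    def decorator(f):
        @wraps(f)
        def wrapper(user_id: int, *args, **kwargs):
            if role not in USER_ROLES.get(user_id, set()):
                return ("Forbidden", 403)  # real code would abort(403)
            return f(user_id, *args, **kwargs)
        return wrapper
    return decorator


@role_required("admin")
def admin_dashboard(user_id: int):
    return ("ok", 200)


print(admin_dashboard(1))  # user 1 has the admin role
print(admin_dashboard(2))  # user 2 does not
```

The real decorator would additionally read `user_id` from the session rather than a parameter.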
@@ -127,6 +164,7 @@ def patch_config():
         "PADDLE_API_KEY": "test_api_key_123",
         "PADDLE_WEBHOOK_SECRET": "whsec_test_secret",
+        "PADDLE_ENVIRONMENT": "sandbox",
         "PADDLE_PRICES": {"starter": "pri_starter_123", "pro": "pri_pro_456"},
 
         "BASE_URL": "http://localhost:5000",
@@ -147,6 +185,32 @@ def patch_config():
 # ── Webhook helpers ──────────────────────────────────────────
 
 
+@pytest.fixture(autouse=True)
+def mock_paddle_verifier(monkeypatch):
+    """Mock Paddle's webhook Verifier to accept test payloads."""
+    def mock_verify(self, payload, secret, signature):
+        if not signature or signature == "invalid_signature":
+            raise ValueError("Invalid signature")
+
+    monkeypatch.setattr(
+        "paddle_billing.Notifications.Verifier.verify",
+        mock_verify,
+    )
+
+
+@pytest.fixture
+def mock_paddle_client(monkeypatch):
+    """Mock _paddle_client() to return a fake PaddleClient."""
+    from unittest.mock import MagicMock
+
+    mock_client = MagicMock()
+    monkeypatch.setattr(
+        "beanflows.billing.routes._paddle_client",
+        lambda: mock_client,
+    )
+    return mock_client
+
+
 def make_webhook_payload(
     event_type: str,
     subscription_id: str = "sub_test456",
@@ -172,76 +236,8 @@ def make_webhook_payload(
     }
 
 
-def sign_payload(payload_bytes: bytes, secret: str = "whsec_test_secret") -> str:
-    """Compute HMAC-SHA256 signature for a webhook payload."""
-    return hmac.new(secret.encode(), payload_bytes, hashlib.sha256).hexdigest()
+def sign_payload(payload_bytes: bytes) -> str:
+    """Return a dummy signature for Paddle webhook tests (Verifier is mocked)."""
+    return "ts=1234567890;h1=dummy_signature"
 
 
-# ── Analytics mock data ──────────────────────────────────────
-
-MOCK_TIME_SERIES = [
-    {"market_year": 2018, "Production": 165000, "Exports": 115000, "Imports": 105000,
-     "Ending_Stocks": 33000, "Total_Distribution": 160000},
-    {"market_year": 2019, "Production": 168000, "Exports": 118000, "Imports": 108000,
-     "Ending_Stocks": 34000, "Total_Distribution": 163000},
-    {"market_year": 2020, "Production": 170000, "Exports": 120000, "Imports": 110000,
-     "Ending_Stocks": 35000, "Total_Distribution": 165000},
-    {"market_year": 2021, "Production": 175000, "Exports": 125000, "Imports": 115000,
-     "Ending_Stocks": 36000, "Total_Distribution": 170000},
-    {"market_year": 2022, "Production": 172000, "Exports": 122000, "Imports": 112000,
-     "Ending_Stocks": 34000, "Total_Distribution": 168000},
-]
-
-MOCK_TOP_COUNTRIES = [
-    {"country_name": "Brazil", "country_code": "BR", "market_year": 2022, "Production": 65000},
-    {"country_name": "Vietnam", "country_code": "VN", "market_year": 2022, "Production": 30000},
-    {"country_name": "Colombia", "country_code": "CO", "market_year": 2022, "Production": 14000},
-]
-
-MOCK_STU_TREND = [
-    {"market_year": 2020, "Stock_to_Use_Ratio_pct": 21.2},
-    {"market_year": 2021, "Stock_to_Use_Ratio_pct": 21.1},
-    {"market_year": 2022, "Stock_to_Use_Ratio_pct": 20.2},
-]
-
-MOCK_BALANCE = [
-    {"market_year": 2020, "Production": 170000, "Total_Distribution": 165000, "Supply_Demand_Balance": 5000},
-    {"market_year": 2021, "Production": 175000, "Total_Distribution": 170000, "Supply_Demand_Balance": 5000},
-    {"market_year": 2022, "Production": 172000, "Total_Distribution": 168000, "Supply_Demand_Balance": 4000},
-]
-
-MOCK_YOY = [
-    {"country_name": "Brazil", "country_code": "BR", "market_year": 2022,
-     "Production": 65000, "Production_YoY_pct": -3.5},
-    {"country_name": "Vietnam", "country_code": "VN", "market_year": 2022,
-     "Production": 30000, "Production_YoY_pct": 2.1},
-]
-
-MOCK_COMMODITIES = [
-    {"commodity_code": 711100, "commodity_name": "Coffee, Green"},
-    {"commodity_code": 222000, "commodity_name": "Soybeans"},
-]
-
-
-@pytest.fixture
-def mock_analytics():
-    """Patch all analytics query functions with mock data."""
-    with patch.object(analytics, "get_global_time_series", new_callable=AsyncMock,
-                      return_value=MOCK_TIME_SERIES), \
-         patch.object(analytics, "get_top_countries", new_callable=AsyncMock,
-                      return_value=MOCK_TOP_COUNTRIES), \
-         patch.object(analytics, "get_stock_to_use_trend", new_callable=AsyncMock,
-                      return_value=MOCK_STU_TREND), \
-         patch.object(analytics, "get_supply_demand_balance", new_callable=AsyncMock,
-                      return_value=MOCK_BALANCE), \
-         patch.object(analytics, "get_production_yoy_by_country", new_callable=AsyncMock,
-                      return_value=MOCK_YOY), \
-         patch.object(analytics, "get_country_comparison", new_callable=AsyncMock,
-                      return_value=[]), \
-         patch.object(analytics, "get_available_commodities", new_callable=AsyncMock,
-                      return_value=MOCK_COMMODITIES), \
-         patch.object(analytics, "fetch_analytics", new_callable=AsyncMock,
-                      return_value=[{"result": 1}]):
-        yield
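The dummy `ts=...;h1=...` string returned by the new `sign_payload` only works because the Verifier fixture above is mocked. For reference, a real Paddle-style `Paddle-Signature` header carries an HMAC-SHA256 of `"{ts}:{raw_body}"` keyed by the webhook secret; a sketch (format as described in Paddle Billing's docs — worth re-checking against current docs before relying on it):

```python
import hashlib
import hmac


def paddle_signature(payload: bytes, secret: str, ts: int) -> str:
    """Build a Paddle-style webhook signature over b"{ts}:{payload}"."""
    mac = hmac.new(secret.encode(), f"{ts}:".encode() + payload, hashlib.sha256)
    return f"ts={ts};h1={mac.hexdigest()}"


sig = paddle_signature(
    b'{"event_type": "subscription.created"}',
    "whsec_test_secret",
    1234567890,
)
```

Keeping the real computation out of the test helpers (and mocking the Verifier instead) avoids coupling the test suite to the provider's signing scheme.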
@@ -9,10 +9,13 @@ from hypothesis import strategies as st
 from beanflows.billing.routes import (
     can_access_feature,
+    get_billing_customer,
     get_subscription,
     get_subscription_by_provider_id,
     is_within_limits,
+    record_transaction,
     update_subscription_status,
+    upsert_billing_customer,
     upsert_subscription,
 )
 from beanflows.core import config
@@ -45,7 +48,6 @@ class TestUpsertSubscription:
             user_id=test_user["id"],
             plan="pro",
             status="active",
-            provider_customer_id="cust_abc",
             provider_subscription_id="sub_xyz",
             current_period_end="2025-06-01T00:00:00Z",
         )
@@ -53,39 +55,53 @@ class TestUpsertSubscription:
         row = await get_subscription(test_user["id"])
         assert row["plan"] == "pro"
         assert row["status"] == "active"
-        assert row["paddle_customer_id"] == "cust_abc"
-        assert row["paddle_subscription_id"] == "sub_xyz"
+        assert row["provider_subscription_id"] == "sub_xyz"
         assert row["current_period_end"] == "2025-06-01T00:00:00Z"
 
-    async def test_update_existing_subscription(self, db, test_user, create_subscription):
-        original_id = await create_subscription(
-            test_user["id"], plan="starter", status="active",
-            paddle_subscription_id="sub_old",
+    async def test_update_existing_by_provider_subscription_id(self, db, test_user):
+        """upsert finds existing by provider_subscription_id, not user_id."""
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="starter",
+            status="active",
+            provider_subscription_id="sub_same",
         )
         returned_id = await upsert_subscription(
             user_id=test_user["id"],
             plan="pro",
             status="active",
-            provider_customer_id="cust_new",
-            provider_subscription_id="sub_new",
+            provider_subscription_id="sub_same",
         )
-        assert returned_id == original_id
         row = await get_subscription(test_user["id"])
         assert row["plan"] == "pro"
-        assert row["paddle_subscription_id"] == "sub_new"
+        assert row["provider_subscription_id"] == "sub_same"
+
+    async def test_different_provider_id_creates_new(self, db, test_user):
+        """Different provider_subscription_id creates a new row (multi-sub support)."""
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="starter",
+            status="active",
+            provider_subscription_id="sub_first",
+        )
+        await upsert_subscription(
+            user_id=test_user["id"],
+            plan="pro",
+            status="active",
+            provider_subscription_id="sub_second",
+        )
+        from beanflows.core import fetch_all
+        rows = await fetch_all(
+            "SELECT * FROM subscriptions WHERE user_id = ? ORDER BY created_at",
+            (test_user["id"],),
+        )
+        assert len(rows) == 2
 
     async def test_upsert_with_none_period_end(self, db, test_user):
         await upsert_subscription(
             user_id=test_user["id"],
             plan="pro",
             status="active",
-            provider_customer_id="cust_1",
             provider_subscription_id="sub_1",
             current_period_end=None,
         )
@@ -93,6 +109,28 @@ class TestUpsertSubscription:
         assert row["current_period_end"] is None
 
 
+# ════════════════════════════════════════════════════════════
+# upsert_billing_customer / get_billing_customer
+# ════════════════════════════════════════════════════════════
+
+class TestUpsertBillingCustomer:
+    async def test_creates_billing_customer(self, db, test_user):
+        await upsert_billing_customer(test_user["id"], "cust_abc")
+        row = await get_billing_customer(test_user["id"])
+        assert row is not None
+        assert row["provider_customer_id"] == "cust_abc"
+
+    async def test_updates_existing_customer(self, db, test_user):
+        await upsert_billing_customer(test_user["id"], "cust_old")
+        await upsert_billing_customer(test_user["id"], "cust_new")
+        row = await get_billing_customer(test_user["id"])
+        assert row["provider_customer_id"] == "cust_new"
+
+    async def test_get_returns_none_for_unknown_user(self, db):
+        row = await get_billing_customer(99999)
+        assert row is None
+
+
 # ════════════════════════════════════════════════════════════
 # get_subscription_by_provider_id
 # ════════════════════════════════════════════════════════════
@@ -102,10 +140,8 @@ class TestGetSubscriptionByProviderId:
         result = await get_subscription_by_provider_id("nonexistent")
         assert result is None
 
-    async def test_finds_by_paddle_subscription_id(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_findme")
+    async def test_finds_by_provider_subscription_id(self, db, test_user, create_subscription):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_findme")
 
         result = await get_subscription_by_provider_id("sub_findme")
         assert result is not None
         assert result["user_id"] == test_user["id"]
@@ -117,18 +153,14 @@ class TestGetSubscriptionByProviderId:
 
 class TestUpdateSubscriptionStatus:
     async def test_updates_status(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_upd")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_upd")
 
         await update_subscription_status("sub_upd", status="cancelled")
         row = await get_subscription(test_user["id"])
         assert row["status"] == "cancelled"
         assert row["updated_at"] is not None
 
     async def test_updates_extra_fields(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_extra")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_extra")
 
         await update_subscription_status(
             "sub_extra",
             status="active",
@@ -141,9 +173,7 @@ class TestUpdateSubscriptionStatus:
         assert row["current_period_end"] == "2026-01-01T00:00:00Z"
 
     async def test_noop_for_unknown_provider_id(self, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_known", status="active")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_known", status="active")
 
         await update_subscription_status("sub_unknown", status="expired")
         row = await get_subscription(test_user["id"])
         assert row["status"] == "active"  # unchanged
@@ -155,22 +185,22 @@ class TestUpdateSubscriptionStatus:
 
 class TestCanAccessFeature:
     async def test_no_subscription_gets_free_features(self, db, test_user):
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
         assert await can_access_feature(test_user["id"], "export") is False
         assert await can_access_feature(test_user["id"], "api") is False
 
     async def test_active_pro_gets_all_features(self, db, test_user, create_subscription):
         await create_subscription(test_user["id"], plan="pro", status="active")
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
         assert await can_access_feature(test_user["id"], "export") is True
         assert await can_access_feature(test_user["id"], "api") is True
         assert await can_access_feature(test_user["id"], "priority_support") is True
 
     async def test_active_starter_gets_starter_features(self, db, test_user, create_subscription):
         await create_subscription(test_user["id"], plan="starter", status="active")
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
         assert await can_access_feature(test_user["id"], "export") is True
-        assert await can_access_feature(test_user["id"], "all_commodities") is False
+        assert await can_access_feature(test_user["id"], "api") is False
 
     async def test_cancelled_still_has_features(self, db, test_user, create_subscription):
         await create_subscription(test_user["id"], plan="pro", status="cancelled")
@@ -183,7 +213,7 @@ class TestCanAccessFeature:
     async def test_expired_falls_back_to_free(self, db, test_user, create_subscription):
         await create_subscription(test_user["id"], plan="pro", status="expired")
         assert await can_access_feature(test_user["id"], "api") is False
-        assert await can_access_feature(test_user["id"], "dashboard") is True
+        assert await can_access_feature(test_user["id"], "basic") is True
 
     async def test_past_due_falls_back_to_free(self, db, test_user, create_subscription):
         await create_subscription(test_user["id"], plan="pro", status="past_due")
@@ -203,30 +233,28 @@ class TestCanAccessFeature:
|
|||||||
# ════════════════════════════════════════════════════════════
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
class TestIsWithinLimits:
|
class TestIsWithinLimits:
|
||||||
async def test_free_user_no_api_calls(self, db, test_user):
|
async def test_free_user_within_limits(self, db, test_user):
|
||||||
assert await is_within_limits(test_user["id"], "api_calls", 0) is False
|
assert await is_within_limits(test_user["id"], "items", 50) is True
|
||||||
|
|
||||||
async def test_free_user_commodity_limit(self, db, test_user):
|
async def test_free_user_at_limit(self, db, test_user):
|
||||||
assert await is_within_limits(test_user["id"], "commodities", 0) is True
|
assert await is_within_limits(test_user["id"], "items", 100) is False
|
||||||
assert await is_within_limits(test_user["id"], "commodities", 1) is False
|
|
||||||
|
|
||||||
async def test_free_user_history_limit(self, db, test_user):
|
async def test_free_user_over_limit(self, db, test_user):
|
||||||
assert await is_within_limits(test_user["id"], "history_years", 4) is True
|
assert await is_within_limits(test_user["id"], "items", 150) is False
|
||||||
assert await is_within_limits(test_user["id"], "history_years", 5) is False
|
|
||||||
|
|
||||||
async def test_pro_unlimited(self, db, test_user, create_subscription):
|
async def test_pro_unlimited(self, db, test_user, create_subscription):
|
||||||
await create_subscription(test_user["id"], plan="pro", status="active")
|
await create_subscription(test_user["id"], plan="pro", status="active")
|
||||||
assert await is_within_limits(test_user["id"], "commodities", 999999) is True
|
assert await is_within_limits(test_user["id"], "items", 999999) is True
|
||||||
assert await is_within_limits(test_user["id"], "api_calls", 999999) is True
|
assert await is_within_limits(test_user["id"], "api_calls", 999999) is True
|
||||||
|
|
||||||
async def test_starter_limits(self, db, test_user, create_subscription):
|
async def test_starter_limits(self, db, test_user, create_subscription):
|
||||||
await create_subscription(test_user["id"], plan="starter", status="active")
|
await create_subscription(test_user["id"], plan="starter", status="active")
|
||||||
assert await is_within_limits(test_user["id"], "api_calls", 9999) is True
|
assert await is_within_limits(test_user["id"], "items", 999) is True
|
||||||
assert await is_within_limits(test_user["id"], "api_calls", 10000) is False
|
assert await is_within_limits(test_user["id"], "items", 1000) is False
|
||||||
|
|
||||||
async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription):
|
async def test_expired_pro_gets_free_limits(self, db, test_user, create_subscription):
|
||||||
await create_subscription(test_user["id"], plan="pro", status="expired")
|
await create_subscription(test_user["id"], plan="pro", status="expired")
|
||||||
assert await is_within_limits(test_user["id"], "api_calls", 0) is False
|
assert await is_within_limits(test_user["id"], "items", 100) is False
|
||||||
|
|
||||||
async def test_unknown_resource_returns_false(self, db, test_user):
|
async def test_unknown_resource_returns_false(self, db, test_user):
|
||||||
assert await is_within_limits(test_user["id"], "unicorns", 0) is False
|
assert await is_within_limits(test_user["id"], "unicorns", 0) is False
|
||||||
@@ -238,7 +266,7 @@ class TestIsWithinLimits:
|
|||||||
# ════════════════════════════════════════════════════════════
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"]
|
STATUSES = ["free", "active", "on_trial", "cancelled", "past_due", "paused", "expired"]
|
||||||
FEATURES = ["dashboard", "export", "api", "priority_support"]
|
FEATURES = ["basic", "export", "api", "priority_support"]
|
||||||
ACTIVE_STATUSES = {"active", "on_trial", "cancelled"}
|
ACTIVE_STATUSES = {"active", "on_trial", "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
@@ -282,9 +310,9 @@ async def test_plan_feature_matrix(db, test_user, create_subscription, plan, fea
|
|||||||
|
|
||||||
@pytest.mark.parametrize("plan", PLANS)
|
@pytest.mark.parametrize("plan", PLANS)
|
||||||
@pytest.mark.parametrize("resource,at_limit", [
|
@pytest.mark.parametrize("resource,at_limit", [
|
||||||
("commodities", 1),
|
("items", 100),
|
||||||
("commodities", 65),
|
("items", 1000),
|
||||||
("api_calls", 0),
|
("api_calls", 1000),
|
||||||
("api_calls", 10000),
|
("api_calls", 10000),
|
||||||
])
|
])
|
||||||
async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit):
|
async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resource, at_limit):
|
||||||
@@ -307,11 +335,11 @@ async def test_plan_limit_matrix(db, test_user, create_subscription, plan, resou
|
|||||||
# ════════════════════════════════════════════════════════════
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
class TestLimitsHypothesis:
|
class TestLimitsHypothesis:
|
||||||
@given(count=st.integers(min_value=0, max_value=100))
|
@given(count=st.integers(min_value=0, max_value=10000))
|
||||||
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||||
async def test_free_limit_boundary_commodities(self, db, test_user, count):
|
async def test_free_limit_boundary_items(self, db, test_user, count):
|
||||||
result = await is_within_limits(test_user["id"], "commodities", count)
|
result = await is_within_limits(test_user["id"], "items", count)
|
||||||
assert result == (count < 1)
|
assert result == (count < 100)
|
||||||
|
|
||||||
@given(count=st.integers(min_value=0, max_value=100000))
|
@given(count=st.integers(min_value=0, max_value=100000))
|
||||||
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
@h_settings(max_examples=100, deadline=2000, suppress_health_check=[HealthCheck.function_scoped_fixture])
|
||||||
@@ -319,7 +347,56 @@ class TestLimitsHypothesis:
|
|||||||
# Use upsert to avoid duplicate inserts across Hypothesis examples
|
# Use upsert to avoid duplicate inserts across Hypothesis examples
|
||||||
await upsert_subscription(
|
await upsert_subscription(
|
||||||
user_id=test_user["id"], plan="pro", status="active",
|
user_id=test_user["id"], plan="pro", status="active",
|
||||||
provider_customer_id="cust_hyp", provider_subscription_id="sub_hyp",
|
provider_subscription_id="sub_hyp",
|
||||||
)
|
)
|
||||||
result = await is_within_limits(test_user["id"], "commodities", count)
|
result = await is_within_limits(test_user["id"], "items", count)
|
||||||
assert result is True
|
assert result is True
|
||||||
|
|
||||||
|
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
# record_transaction
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestRecordTransaction:
|
||||||
|
async def test_inserts_transaction(self, db, test_user):
|
||||||
|
txn_id = await record_transaction(
|
||||||
|
user_id=test_user["id"],
|
||||||
|
provider_transaction_id="txn_abc123",
|
||||||
|
type="payment",
|
||||||
|
amount_cents=2999,
|
||||||
|
currency="EUR",
|
||||||
|
status="completed",
|
||||||
|
)
|
||||||
|
assert txn_id is not None and txn_id > 0
|
||||||
|
|
||||||
|
from beanflows.core import fetch_one
|
||||||
|
row = await fetch_one(
|
||||||
|
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
|
||||||
|
("txn_abc123",),
|
||||||
|
)
|
||||||
|
assert row is not None
|
||||||
|
assert row["user_id"] == test_user["id"]
|
||||||
|
assert row["amount_cents"] == 2999
|
||||||
|
assert row["currency"] == "EUR"
|
||||||
|
assert row["status"] == "completed"
|
||||||
|
|
||||||
|
async def test_idempotent_on_duplicate_provider_id(self, db, test_user):
|
||||||
|
await record_transaction(
|
||||||
|
user_id=test_user["id"],
|
||||||
|
provider_transaction_id="txn_dup",
|
||||||
|
amount_cents=1000,
|
||||||
|
)
|
||||||
|
# Second insert with same provider_transaction_id should be ignored
|
||||||
|
await record_transaction(
|
||||||
|
user_id=test_user["id"],
|
||||||
|
provider_transaction_id="txn_dup",
|
||||||
|
amount_cents=9999,
|
||||||
|
)
|
||||||
|
|
||||||
|
from beanflows.core import fetch_all
|
||||||
|
rows = await fetch_all(
|
||||||
|
"SELECT * FROM transactions WHERE provider_transaction_id = ?",
|
||||||
|
("txn_dup",),
|
||||||
|
)
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0]["amount_cents"] == 1000 # original value preserved
|
||||||
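The idempotency the new `TestRecordTransaction` cases assert (a duplicate `provider_transaction_id` is silently ignored and the original row preserved) is the classic UNIQUE constraint plus `INSERT OR IGNORE` pattern. A minimal sqlite3 sketch, with a simplified stand-in table and signature rather than beanflows' actual schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE transactions (
        id INTEGER PRIMARY KEY,
        user_id INTEGER NOT NULL,
        provider_transaction_id TEXT NOT NULL UNIQUE,
        amount_cents INTEGER NOT NULL
    )
""")

def record_transaction(user_id, provider_transaction_id, amount_cents):
    # INSERT OR IGNORE makes a duplicate provider ID a silent no-op,
    # so a retried webhook never double-records a payment.
    cur = conn.execute(
        "INSERT OR IGNORE INTO transactions (user_id, provider_transaction_id, amount_cents) "
        "VALUES (?, ?, ?)",
        (user_id, provider_transaction_id, amount_cents),
    )
    return cur.lastrowid if cur.rowcount else None

record_transaction(1, "txn_dup", 1000)
record_transaction(1, "txn_dup", 9999)  # ignored: same provider ID

row = conn.execute(
    "SELECT amount_cents FROM transactions WHERE provider_transaction_id = ?",
    ("txn_dup",),
).fetchone()
print(row[0])  # original value preserved
```

The UNIQUE constraint does the deduplication in the database, so concurrent webhook deliveries are safe without application-level locking.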
web/tests/test_billing_hooks.py (new file, 122 lines)
@@ -0,0 +1,122 @@
+"""
+Tests for the billing event hook system.
+"""
+import pytest
+
+from beanflows.billing.routes import _billing_hooks, _fire_hooks, on_billing_event
+
+
+@pytest.fixture(autouse=True)
+def clear_hooks():
+    """Ensure hooks are clean before and after each test."""
+    _billing_hooks.clear()
+    yield
+    _billing_hooks.clear()
+
+
+# ════════════════════════════════════════════════════════════
+# Registration
+# ════════════════════════════════════════════════════════════
+
+class TestOnBillingEvent:
+    def test_registers_single_event(self):
+        @on_billing_event("subscription.activated")
+        async def my_hook(event_type, data):
+            pass
+
+        assert "subscription.activated" in _billing_hooks
+        assert my_hook in _billing_hooks["subscription.activated"]
+
+    def test_registers_multiple_events(self):
+        @on_billing_event("subscription.activated", "subscription.updated")
+        async def my_hook(event_type, data):
+            pass
+
+        assert my_hook in _billing_hooks["subscription.activated"]
+        assert my_hook in _billing_hooks["subscription.updated"]
+
+    def test_multiple_hooks_per_event(self):
+        @on_billing_event("subscription.activated")
+        async def hook_a(event_type, data):
+            pass
+
+        @on_billing_event("subscription.activated")
+        async def hook_b(event_type, data):
+            pass
+
+        assert len(_billing_hooks["subscription.activated"]) == 2
+
+    def test_decorator_returns_original_function(self):
+        @on_billing_event("test_event")
+        async def my_hook(event_type, data):
+            pass
+
+        assert my_hook.__name__ == "my_hook"
+
+
+# ════════════════════════════════════════════════════════════
+# Firing
+# ════════════════════════════════════════════════════════════
+
+class TestFireHooks:
+    async def test_fires_registered_hook(self):
+        calls = []
+
+        @on_billing_event("subscription.activated")
+        async def recorder(event_type, data):
+            calls.append((event_type, data))
+
+        await _fire_hooks("subscription.activated", {"id": "sub_123"})
+        assert len(calls) == 1
+        assert calls[0] == ("subscription.activated", {"id": "sub_123"})
+
+    async def test_no_hooks_registered_is_noop(self):
+        # Should not raise
+        await _fire_hooks("unregistered_event", {"id": "sub_123"})
+
+    async def test_fires_all_hooks_for_event(self):
+        calls = []
+
+        @on_billing_event("subscription.activated")
+        async def hook_a(event_type, data):
+            calls.append("a")
+
+        @on_billing_event("subscription.activated")
+        async def hook_b(event_type, data):
+            calls.append("b")
+
+        await _fire_hooks("subscription.activated", {})
+        assert calls == ["a", "b"]
+
+
+# ════════════════════════════════════════════════════════════
+# Error isolation
+# ════════════════════════════════════════════════════════════
+
+class TestHookErrorIsolation:
+    async def test_failing_hook_does_not_block_others(self):
+        calls = []
+
+        @on_billing_event("subscription.activated")
+        async def failing_hook(event_type, data):
+            raise RuntimeError("boom")
+
+        @on_billing_event("subscription.activated")
+        async def good_hook(event_type, data):
+            calls.append("ok")
+
+        # Should not raise despite first hook failing
+        await _fire_hooks("subscription.activated", {})
+        assert calls == ["ok"]
+
+    async def test_failing_hook_is_logged(self, caplog):
+        @on_billing_event("subscription.activated")
+        async def bad_hook(event_type, data):
+            raise ValueError("test error")
+
+        import logging
+        with caplog.at_level(logging.ERROR):
+            await _fire_hooks("subscription.activated", {})
+
+        assert "bad_hook" in caplog.text
+        assert "test error" in caplog.text
@@ -1,12 +1,15 @@
 """
 Route integration tests for Paddle billing endpoints.
-External Paddle API calls mocked with respx.
-"""
-import json
-
-import httpx
+Paddle SDK calls mocked via mock_paddle_client fixture.
+"""
+
+from unittest.mock import MagicMock
+
 import pytest
-import respx
 
 
 CHECKOUT_METHOD = "POST"
@@ -54,24 +57,16 @@ class TestCheckoutRoute:
 
         assert response.status_code in (302, 303, 307)
 
-    @respx.mock
-    async def test_creates_checkout_session(self, auth_client, db, test_user):
-        respx.post("https://api.paddle.com/transactions").mock(
-            return_value=httpx.Response(200, json={
-                "data": {
-                    "checkout": {
-                        "url": "https://checkout.paddle.com/test_123"
-                    }
-                }
-            })
-        )
+    async def test_creates_checkout_session(self, auth_client, db, test_user, mock_paddle_client):
+        mock_txn = MagicMock()
+        mock_txn.checkout.url = "https://checkout.paddle.com/test_123"
+        mock_paddle_client.transactions.create.return_value = mock_txn
 
         response = await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}", follow_redirects=False)
 
         assert response.status_code in (302, 303, 307)
+        mock_paddle_client.transactions.create.assert_called_once()
 
     async def test_invalid_plan_rejected(self, auth_client, db, test_user):
 
@@ -82,20 +77,13 @@ class TestCheckoutRoute:
 
-    @respx.mock
-    async def test_api_error_propagates(self, auth_client, db, test_user):
-        respx.post("https://api.paddle.com/transactions").mock(
-            return_value=httpx.Response(500, json={"error": "server error"})
-        )
-
-        with pytest.raises(httpx.HTTPStatusError):
+    async def test_api_error_propagates(self, auth_client, db, test_user, mock_paddle_client):
+        mock_paddle_client.transactions.create.side_effect = Exception("API error")
+        with pytest.raises(Exception, match="API error"):
             await auth_client.post(f"/billing/checkout/{CHECKOUT_PLAN}")
 
 
 # ════════════════════════════════════════════════════════════
 # Manage subscription / Portal
 # ════════════════════════════════════════════════════════════
@@ -110,24 +98,18 @@ class TestManageRoute:
         response = await auth_client.post("/billing/manage", follow_redirects=False)
         assert response.status_code in (302, 303, 307)
 
-    @respx.mock
-    async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
-        respx.get("https://api.paddle.com/subscriptions/sub_test").mock(
-            return_value=httpx.Response(200, json={
-                "data": {
-                    "management_urls": {
-                        "update_payment_method": "https://paddle.com/manage/test_123"
-                    }
-                }
-            })
-        )
+    async def test_redirects_to_portal(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test")
+        mock_sub = MagicMock()
+        mock_sub.management_urls.update_payment_method = "https://paddle.com/manage/test_123"
+        mock_paddle_client.subscriptions.get.return_value = mock_sub
 
         response = await auth_client.post("/billing/manage", follow_redirects=False)
         assert response.status_code in (302, 303, 307)
+        mock_paddle_client.subscriptions.get.assert_called_once_with("sub_test")
 
 
@@ -145,18 +127,14 @@ class TestCancelRoute:
         response = await auth_client.post("/billing/cancel", follow_redirects=False)
         assert response.status_code in (302, 303, 307)
 
-    @respx.mock
-    async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test")
-        respx.post("https://api.paddle.com/subscriptions/sub_test/cancel").mock(
-            return_value=httpx.Response(200, json={"data": {}})
-        )
+    async def test_cancels_subscription(self, auth_client, db, test_user, create_subscription, mock_paddle_client):
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test")
 
         response = await auth_client.post("/billing/cancel", follow_redirects=False)
         assert response.status_code in (302, 303, 307)
+        mock_paddle_client.subscriptions.cancel.assert_called_once()
 
 
@@ -167,8 +145,9 @@ class TestCancelRoute:
 # subscription_required decorator
 # ════════════════════════════════════════════════════════════
 
-from beanflows.billing.routes import subscription_required
-from quart import Blueprint
+from quart import Blueprint  # noqa: E402
+from beanflows.auth.routes import subscription_required  # noqa: E402
 
 test_bp = Blueprint("test", __name__)
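The rewrite above swaps respx HTTP mocking for a mocked Paddle SDK client. The pattern works because `MagicMock` auto-creates nested attributes, so the route code's `client.transactions.create(...)` chain resolves without any real HTTP. A self-contained sketch of that mechanism; `mock_paddle_client` here is a local stand-in for the pytest fixture of the same name that the tests assume lives in conftest:

```python
from unittest.mock import MagicMock

# Accessing any attribute on a MagicMock creates a child mock on the
# fly, so arbitrary SDK call chains can be stubbed without a real client.
mock_paddle_client = MagicMock()

mock_txn = MagicMock()
mock_txn.checkout.url = "https://checkout.paddle.com/test_123"
mock_paddle_client.transactions.create.return_value = mock_txn

# The route-under-test would make this call internally.
txn = mock_paddle_client.transactions.create({"plan": "starter"})
print(txn.checkout.url)  # https://checkout.paddle.com/test_123

# And the test can then assert on how the SDK was driven.
mock_paddle_client.transactions.create.assert_called_once()
```

Compared to respx, this mocks at the SDK boundary rather than the wire, so the tests no longer depend on Paddle's URL layout or JSON envelope.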
@@ -5,13 +5,13 @@ Covers signature verification, event parsing, subscription lifecycle transitions
 import json
 
 import pytest
+from conftest import make_webhook_payload, sign_payload
+
 from hypothesis import HealthCheck, given
 from hypothesis import settings as h_settings
 from hypothesis import strategies as st
 
-from beanflows.billing.routes import get_subscription
+from beanflows.billing.routes import get_billing_customer, get_subscription
 
-from conftest import make_webhook_payload, sign_payload
 
 
 WEBHOOK_PATH = "/billing/webhook/paddle"
@@ -72,18 +72,19 @@ class TestWebhookSignature:
 
     async def test_modified_payload_rejected(self, client, db, test_user):
+        # Paddle SDK Verifier handles tamper detection internally.
+        # We test signature rejection via test_invalid_signature_rejected above.
+        # This test verifies the Verifier is actually called by sending
+        # a payload with an explicitly bad signature.
         payload = make_webhook_payload("subscription.activated", user_id=str(test_user["id"]))
         payload_bytes = json.dumps(payload).encode()
-        sig = sign_payload(payload_bytes)
-        tampered = payload_bytes + b"extra"
 
-        # Paddle/LemonSqueezy: HMAC signature verification fails before JSON parsing
         response = await client.post(
             WEBHOOK_PATH,
-            data=tampered,
-            headers={SIG_HEADER: sig, "Content-Type": "application/json"},
+            data=payload_bytes,
+            headers={SIG_HEADER: "invalid_signature", "Content-Type": "application/json"},
         )
-        assert response.status_code in (400, 401)
+        assert response.status_code == 400
 
     async def test_empty_payload_rejected(self, client, db):
@@ -105,7 +106,7 @@ class TestWebhookSignature:
 
 
 class TestWebhookSubscriptionActivated:
-    async def test_creates_subscription(self, client, db, test_user):
+    async def test_creates_subscription_and_billing_customer(self, client, db, test_user):
         payload = make_webhook_payload(
             "subscription.activated",
             user_id=str(test_user["id"]),
@@ -126,10 +127,14 @@ class TestWebhookSubscriptionActivated:
         assert sub["plan"] == "starter"
         assert sub["status"] == "active"
 
+        bc = await get_billing_customer(test_user["id"])
+        assert bc is not None
+        assert bc["provider_customer_id"] == "ctm_test123"
 
 
 class TestWebhookSubscriptionUpdated:
     async def test_updates_subscription_status(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
 
         payload = make_webhook_payload(
             "subscription.updated",
@@ -152,7 +157,7 @@ class TestWebhookSubscriptionUpdated:
 
 class TestWebhookSubscriptionCanceled:
     async def test_marks_subscription_cancelled(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
 
         payload = make_webhook_payload(
             "subscription.canceled",
@@ -174,7 +179,7 @@ class TestWebhookSubscriptionCanceled:
 
 class TestWebhookSubscriptionPastDue:
     async def test_marks_subscription_past_due(self, client, db, test_user, create_subscription):
-        await create_subscription(test_user["id"], status="active", paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], status="active", provider_subscription_id="sub_test456")
 
         payload = make_webhook_payload(
             "subscription.past_due",
@@ -209,7 +214,7 @@ class TestWebhookSubscriptionPastDue:
 ])
 async def test_event_status_transitions(client, db, test_user, create_subscription, event_type, expected_status):
     if event_type != "subscription.activated":
-        await create_subscription(test_user["id"], paddle_subscription_id="sub_test456")
+        await create_subscription(test_user["id"], provider_subscription_id="sub_test456")
 
     payload = make_webhook_payload(event_type, user_id=str(test_user["id"]))
     payload_bytes = json.dumps(payload).encode()
web/tests/test_roles.py (new file, 242 lines)
@@ -0,0 +1,242 @@
|
"""
|
||||||
|
Tests for role-based access control: role_required decorator, grant/revoke/ensure_admin_role,
|
||||||
|
and admin route protection.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from quart import Blueprint
|
||||||
|
|
||||||
|
from beanflows.auth.routes import (
|
||||||
|
ensure_admin_role,
|
||||||
|
grant_role,
|
||||||
|
revoke_role,
|
||||||
|
role_required,
|
||||||
|
)
|
||||||
|
from beanflows import core
|
||||||
|
|
||||||
|
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
# grant_role / revoke_role
|
||||||
|
# ════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
class TestGrantRole:
|
||||||
|
async def test_grants_role(self, db, test_user):
|
||||||
|
await grant_role(test_user["id"], "admin")
|
||||||
|
row = await core.fetch_one(
|
||||||
|
"SELECT role FROM user_roles WHERE user_id = ?",
|
||||||
|
(test_user["id"],),
|
||||||
|
)
|
||||||
|
assert row is not None
|
||||||
|
assert row["role"] == "admin"
|
||||||
|
|
||||||
|
async def test_idempotent(self, db, test_user):
|
||||||
|
await grant_role(test_user["id"], "admin")
|
||||||
|
await grant_role(test_user["id"], "admin")
|
||||||
|
rows = await core.fetch_all(
|
||||||
|
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
|
||||||
|
(test_user["id"],),
|
||||||
|
)
|
||||||
|
assert len(rows) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevokeRole:
|
||||||
|
async def test_revokes_existing_role(self, db, test_user):
|
||||||
|
await grant_role(test_user["id"], "admin")
|
||||||
|
await revoke_role(test_user["id"], "admin")
|
||||||
|
row = await core.fetch_one(
|
||||||
|
"SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
|
||||||
|
(test_user["id"],),
|
||||||
|
)
|
||||||
|
assert row is None
|
||||||
|
|
||||||
|
async def test_noop_for_missing_role(self, db, test_user):
|
||||||
|
# Should not raise
|
||||||
|
await revoke_role(test_user["id"], "nonexistent")
|
||||||
|
|
||||||
|
|
||||||
|
# ════════════════════════════════════════════════════════════
# ensure_admin_role
# ════════════════════════════════════════════════════════════


class TestEnsureAdminRole:
    async def test_grants_admin_for_listed_email(self, db, test_user):
        core.config.ADMIN_EMAILS = ["test@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "test@example.com")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is not None
        finally:
            core.config.ADMIN_EMAILS = []

    async def test_skips_for_unlisted_email(self, db, test_user):
        core.config.ADMIN_EMAILS = ["boss@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "test@example.com")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is None
        finally:
            core.config.ADMIN_EMAILS = []

    async def test_empty_admin_emails_grants_nothing(self, db, test_user):
        core.config.ADMIN_EMAILS = []
        await ensure_admin_role(test_user["id"], "test@example.com")
        row = await core.fetch_one(
            "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
            (test_user["id"],),
        )
        assert row is None

    async def test_case_insensitive_matching(self, db, test_user):
        core.config.ADMIN_EMAILS = ["test@example.com"]
        try:
            await ensure_admin_role(test_user["id"], "Test@Example.COM")
            row = await core.fetch_one(
                "SELECT role FROM user_roles WHERE user_id = ? AND role = 'admin'",
                (test_user["id"],),
            )
            assert row is not None
        finally:
            core.config.ADMIN_EMAILS = []
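The tests above pin down the contract of `ensure_admin_role`: grant `admin` only when the login email is on the allowlist, compare case-insensitively, and grant nothing when the list is empty. A minimal synchronous sketch of that contract (the real function is async and uses the app's `core` database helpers; `ADMIN_EMAILS` and the `sqlite3` wiring below are stand-ins):

```python
import sqlite3

# Hypothetical stand-in for core.config.ADMIN_EMAILS.
ADMIN_EMAILS = ["test@example.com"]


def ensure_admin_role_sketch(conn: sqlite3.Connection, user_id: int, email: str) -> None:
    """Grant 'admin' when the login email is allowlisted (case-insensitive).

    INSERT OR IGNORE keeps the grant idempotent across repeated logins.
    """
    allowlist = {e.lower() for e in ADMIN_EMAILS}
    if email.lower() not in allowlist:
        return  # unlisted email: grant nothing
    conn.execute(
        "INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, 'admin')",
        (user_id,),
    )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE user_roles (user_id INTEGER, role TEXT, UNIQUE(user_id, role))")
ensure_admin_role_sketch(conn, 1, "Test@Example.COM")   # listed (case-insensitive) -> granted
ensure_admin_role_sketch(conn, 2, "other@example.com")  # unlisted -> skipped
rows = conn.execute("SELECT user_id FROM user_roles WHERE role = 'admin'").fetchall()
print(rows)  # [(1,)]
```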


# ════════════════════════════════════════════════════════════
# role_required decorator
# ════════════════════════════════════════════════════════════

role_test_bp = Blueprint("role_test", __name__)


@role_test_bp.route("/admin-only")
@role_required("admin")
async def admin_only_route():
    return "admin-ok", 200


@role_test_bp.route("/multi-role")
@role_required("admin", "editor")
async def multi_role_route():
    return "multi-ok", 200


class TestRoleRequired:
    @pytest.fixture
    async def role_app(self, app):
        app.register_blueprint(role_test_bp)
        return app

    @pytest.fixture
    async def role_client(self, role_app):
        async with role_app.test_client() as c:
            yield c

    async def test_redirects_unauthenticated(self, role_client, db):
        response = await role_client.get("/admin-only", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_rejects_user_without_role(self, role_client, db, test_user):
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/admin-only", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_allows_user_with_matching_role(self, role_client, db, test_user):
        await grant_role(test_user["id"], "admin")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/admin-only")
        assert response.status_code == 200

    async def test_multi_role_allows_any_match(self, role_client, db, test_user):
        await grant_role(test_user["id"], "editor")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/multi-role")
        assert response.status_code == 200

    async def test_multi_role_rejects_none(self, role_client, db, test_user):
        await grant_role(test_user["id"], "viewer")
        async with role_client.session_transaction() as sess:
            sess["user_id"] = test_user["id"]

        response = await role_client.get("/multi-role", follow_redirects=False)
        assert response.status_code in (302, 303, 307)
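These tests document the decorator's semantics: unauthenticated requests redirect, a user lacking every listed role redirects, and holding any one of the listed roles is enough. A simplified synchronous sketch of that pattern (the real decorator is async and issues HTTP redirects; `session` and `user_roles_db` below are hypothetical stand-ins for the app's session and role lookup):

```python
from functools import wraps

user_roles_db = {1: {"editor"}}  # user_id -> set of roles (stand-in for the user_roles table)
session = {}                     # stand-in for the request session


def role_required(*roles):
    """Allow the wrapped view only if the session user holds any listed role."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            user_id = session.get("user_id")
            if user_id is None:
                return "redirect:/login", 302  # unauthenticated -> redirect
            if not user_roles_db.get(user_id, set()) & set(roles):
                return "redirect:/", 302       # no matching role -> redirect
            return fn(*args, **kwargs)         # any single listed role suffices
        return wrapper
    return decorator


@role_required("admin", "editor")
def multi_role_route():
    return "multi-ok", 200


session["user_id"] = 1
print(multi_role_route())  # ('multi-ok', 200) -- "editor" matches one allowed role
```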


# ════════════════════════════════════════════════════════════
# Admin route protection
# ════════════════════════════════════════════════════════════


class TestAdminRouteProtection:
    async def test_admin_index_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_admin_index_accessible_with_admin_role(self, admin_client, db):
        response = await admin_client.get("/admin/")
        assert response.status_code == 200

    async def test_admin_users_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/users", follow_redirects=False)
        assert response.status_code in (302, 303, 307)

    async def test_admin_tasks_requires_admin_role(self, auth_client, db):
        response = await auth_client.get("/admin/tasks", follow_redirects=False)
        assert response.status_code in (302, 303, 307)


# ════════════════════════════════════════════════════════════
# Impersonation
# ════════════════════════════════════════════════════════════


class TestImpersonation:
    async def test_impersonate_stores_admin_id(self, admin_client, db, test_user):
        """Impersonating stores the admin's user_id in session['admin_impersonating']."""
        # Create a second user to impersonate
        now = "2025-01-01T00:00:00"
        other_id = await core.execute(
            "INSERT INTO users (email, name, created_at) VALUES (?, ?, ?)",
            ("other@example.com", "Other", now),
        )

        async with admin_client.session_transaction() as sess:
            sess["csrf_token"] = "test_csrf"

        response = await admin_client.post(
            f"/admin/users/{other_id}/impersonate",
            form={"csrf_token": "test_csrf"},
            follow_redirects=False,
        )
        assert response.status_code in (302, 303, 307)

        async with admin_client.session_transaction() as sess:
            assert sess["user_id"] == other_id
            assert sess["admin_impersonating"] == test_user["id"]

    async def test_stop_impersonating_restores_admin(self, app, db, test_user, grant_role):
        """Stopping impersonation restores the admin's user_id."""
        await grant_role(test_user["id"], "admin")

        async with app.test_client() as c:
            async with c.session_transaction() as sess:
                sess["user_id"] = 999  # impersonated user
                sess["admin_impersonating"] = test_user["id"]
                sess["csrf_token"] = "test_csrf"

            response = await c.post(
                "/admin/stop-impersonating",
                form={"csrf_token": "test_csrf"},
                follow_redirects=False,
            )
            assert response.status_code in (302, 303, 307)

            async with c.session_transaction() as sess:
                assert sess["user_id"] == test_user["id"]
                assert "admin_impersonating" not in sess
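The impersonation tests reduce to a small piece of session bookkeeping: stash the admin's id under `admin_impersonating`, swap `user_id` to the target, and reverse the swap on stop. A sketch of just that state transition (function names are illustrative; the real routes also enforce the admin role and CSRF):

```python
def impersonate(session: dict, admin_id: int, target_id: int) -> None:
    session["admin_impersonating"] = admin_id  # remember who the real admin is
    session["user_id"] = target_id             # act as the target user


def stop_impersonating(session: dict) -> None:
    # pop() both restores the admin's id and clears the impersonation marker
    session["user_id"] = session.pop("admin_impersonating")


sess = {"user_id": 7}  # admin with id 7 is logged in
impersonate(sess, admin_id=7, target_id=999)
print(sess)            # {'user_id': 999, 'admin_impersonating': 7}
stop_impersonating(sess)
print(sess)            # {'user_id': 7}
```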