Merge branch 'copier-update'
This commit is contained in:
@@ -12,24 +12,24 @@ from .core import close_db, config, get_csrf_token, init_db, setup_request_id
|
||||
|
||||
def create_app() -> Quart:
|
||||
"""Create and configure the Quart application."""
|
||||
|
||||
|
||||
# Get package directory for templates
|
||||
pkg_dir = Path(__file__).parent
|
||||
|
||||
|
||||
app = Quart(
|
||||
__name__,
|
||||
template_folder=str(pkg_dir / "templates"),
|
||||
static_folder=str(pkg_dir / "static"),
|
||||
)
|
||||
|
||||
|
||||
app.secret_key = config.SECRET_KEY
|
||||
|
||||
|
||||
# Session config
|
||||
app.config["SESSION_COOKIE_SECURE"] = not config.DEBUG
|
||||
app.config["SESSION_COOKIE_HTTPONLY"] = True
|
||||
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
|
||||
app.config["PERMANENT_SESSION_LIFETIME"] = 60 * 60 * 24 * config.SESSION_LIFETIME_DAYS
|
||||
|
||||
|
||||
# Database lifecycle
|
||||
@app.before_serving
|
||||
async def startup():
|
||||
@@ -41,7 +41,7 @@ def create_app() -> Quart:
|
||||
async def shutdown():
|
||||
close_analytics_db()
|
||||
await close_db()
|
||||
|
||||
|
||||
# Security headers
|
||||
@app.after_request
|
||||
async def add_security_headers(response):
|
||||
@@ -51,16 +51,42 @@ def create_app() -> Quart:
|
||||
if not config.DEBUG:
|
||||
response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"
|
||||
return response
|
||||
|
||||
# Load current user before each request
|
||||
|
||||
# Load current user + subscription + roles before each request
|
||||
@app.before_request
|
||||
async def load_user():
|
||||
g.user = None
|
||||
g.subscription = None
|
||||
user_id = session.get("user_id")
|
||||
if user_id:
|
||||
from .auth.routes import get_user_by_id
|
||||
g.user = await get_user_by_id(user_id)
|
||||
|
||||
from .core import fetch_one as _fetch_one
|
||||
row = await _fetch_one(
|
||||
"""SELECT u.*,
|
||||
bc.provider_customer_id,
|
||||
(SELECT GROUP_CONCAT(role) FROM user_roles WHERE user_id = u.id) AS roles_csv,
|
||||
s.id AS sub_id, s.plan, s.status AS sub_status,
|
||||
s.provider_subscription_id, s.current_period_end
|
||||
FROM users u
|
||||
LEFT JOIN billing_customers bc ON bc.user_id = u.id
|
||||
LEFT JOIN subscriptions s ON s.id = (
|
||||
SELECT id FROM subscriptions
|
||||
WHERE user_id = u.id
|
||||
ORDER BY created_at DESC LIMIT 1
|
||||
)
|
||||
WHERE u.id = ? AND u.deleted_at IS NULL""",
|
||||
(user_id,),
|
||||
)
|
||||
if row:
|
||||
g.user = dict(row)
|
||||
g.user["roles"] = row["roles_csv"].split(",") if row["roles_csv"] else []
|
||||
if row["sub_id"]:
|
||||
g.subscription = {
|
||||
"id": row["sub_id"], "plan": row["plan"],
|
||||
"status": row["sub_status"],
|
||||
"provider_subscription_id": row["provider_subscription_id"],
|
||||
"current_period_end": row["current_period_end"],
|
||||
}
|
||||
|
||||
# Template context globals
|
||||
@app.context_processor
|
||||
def inject_globals():
|
||||
@@ -68,10 +94,14 @@ def create_app() -> Quart:
|
||||
return {
|
||||
"config": config,
|
||||
"user": g.get("user"),
|
||||
"subscription": g.get("subscription"),
|
||||
"is_admin": "admin" in (g.get("user") or {}).get("roles", []),
|
||||
"now": datetime.utcnow(),
|
||||
"csrf_token": get_csrf_token,
|
||||
"ab_variant": getattr(g, "ab_variant", None),
|
||||
"ab_tag": getattr(g, "ab_tag", None),
|
||||
}
|
||||
|
||||
|
||||
# Health check
|
||||
@app.route("/health")
|
||||
async def health():
|
||||
@@ -94,7 +124,7 @@ def create_app() -> Quart:
|
||||
result["duckdb"] = "not configured"
|
||||
status_code = 200 if result["status"] == "healthy" else 500
|
||||
return result, status_code
|
||||
|
||||
|
||||
# Register blueprints
|
||||
from .admin.routes import bp as admin_bp
|
||||
from .api.routes import bp as api_bp
|
||||
@@ -102,17 +132,17 @@ def create_app() -> Quart:
|
||||
from .billing.routes import bp as billing_bp
|
||||
from .dashboard.routes import bp as dashboard_bp
|
||||
from .public.routes import bp as public_bp
|
||||
|
||||
|
||||
app.register_blueprint(public_bp)
|
||||
app.register_blueprint(auth_bp)
|
||||
app.register_blueprint(dashboard_bp)
|
||||
app.register_blueprint(billing_bp)
|
||||
app.register_blueprint(api_bp, url_prefix="/api/v1")
|
||||
app.register_blueprint(admin_bp)
|
||||
|
||||
|
||||
# Request ID tracking
|
||||
setup_request_id(app)
|
||||
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ async def get_valid_token(token: str) -> dict | None:
|
||||
"""Get token if valid and not expired."""
|
||||
return await fetch_one(
|
||||
"""
|
||||
SELECT at.*, u.email
|
||||
SELECT at.*, u.email
|
||||
FROM auth_tokens at
|
||||
JOIN users u ON u.id = at.user_id
|
||||
WHERE at.token = ? AND at.expires_at > ? AND at.used_at IS NULL
|
||||
@@ -88,19 +88,6 @@ async def mark_token_used(token_id: int) -> None:
|
||||
)
|
||||
|
||||
|
||||
async def get_user_with_subscription(user_id: int) -> dict | None:
|
||||
"""Get user with their active subscription info."""
|
||||
return await fetch_one(
|
||||
"""
|
||||
SELECT u.*, s.plan, s.status as sub_status, s.current_period_end
|
||||
FROM users u
|
||||
LEFT JOIN subscriptions s ON s.user_id = u.id AND s.status = 'active'
|
||||
WHERE u.id = ? AND u.deleted_at IS NULL
|
||||
""",
|
||||
(user_id,)
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Decorators
|
||||
# =============================================================================
|
||||
@@ -116,24 +103,69 @@ def login_required(f):
|
||||
return decorated
|
||||
|
||||
|
||||
def subscription_required(plans: list[str] = None):
|
||||
"""Require active subscription, optionally of specific plan(s)."""
|
||||
def role_required(*roles):
|
||||
"""Require user to have at least one of the given roles."""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def decorated(*args, **kwargs):
|
||||
if not g.get("user"):
|
||||
await flash("Please sign in to continue.", "warning")
|
||||
return redirect(url_for("auth.login", next=request.path))
|
||||
user_roles = g.user.get("roles", [])
|
||||
if not any(r in user_roles for r in roles):
|
||||
await flash("You don't have permission to access that page.", "error")
|
||||
return redirect(url_for("dashboard.index"))
|
||||
return await f(*args, **kwargs)
|
||||
return decorated
|
||||
return decorator
|
||||
|
||||
|
||||
async def grant_role(user_id: int, role: str) -> None:
|
||||
"""Grant a role to a user (idempotent)."""
|
||||
await execute(
|
||||
"INSERT OR IGNORE INTO user_roles (user_id, role) VALUES (?, ?)",
|
||||
(user_id, role),
|
||||
)
|
||||
|
||||
|
||||
async def revoke_role(user_id: int, role: str) -> None:
|
||||
"""Revoke a role from a user."""
|
||||
await execute(
|
||||
"DELETE FROM user_roles WHERE user_id = ? AND role = ?",
|
||||
(user_id, role),
|
||||
)
|
||||
|
||||
|
||||
async def ensure_admin_role(user_id: int, email: str) -> None:
|
||||
"""Grant admin role if email is in ADMIN_EMAILS."""
|
||||
if email.lower() in config.ADMIN_EMAILS:
|
||||
await grant_role(user_id, "admin")
|
||||
|
||||
|
||||
def subscription_required(
|
||||
plans: list[str] = None,
|
||||
allowed: tuple[str, ...] = ("active", "on_trial", "cancelled"),
|
||||
):
|
||||
"""Require active subscription, optionally of specific plan(s) and/or statuses.
|
||||
|
||||
Reads from g.subscription (eager-loaded in load_user) — zero extra queries.
|
||||
"""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def decorated(*args, **kwargs):
|
||||
if not g.get("user"):
|
||||
await flash("Please sign in to continue.", "warning")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
user = await get_user_with_subscription(g.user["id"])
|
||||
if not user or not user.get("plan"):
|
||||
|
||||
sub = g.get("subscription")
|
||||
if not sub or sub["status"] not in allowed:
|
||||
await flash("Please subscribe to access this feature.", "warning")
|
||||
return redirect(url_for("billing.pricing"))
|
||||
|
||||
if plans and user["plan"] not in plans:
|
||||
|
||||
if plans and sub["plan"] not in plans:
|
||||
await flash(f"This feature requires a {' or '.join(plans)} plan.", "warning")
|
||||
return redirect(url_for("billing.pricing"))
|
||||
|
||||
|
||||
return await f(*args, **kwargs)
|
||||
return decorated
|
||||
return decorator
|
||||
@@ -149,33 +181,33 @@ async def login():
|
||||
"""Login page - request magic link."""
|
||||
if g.get("user"):
|
||||
return redirect(url_for("dashboard.index"))
|
||||
|
||||
|
||||
if request.method == "POST":
|
||||
form = await request.form
|
||||
email = form.get("email", "").strip().lower()
|
||||
|
||||
|
||||
if not email or "@" not in email:
|
||||
await flash("Please enter a valid email address.", "error")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
|
||||
# Get or create user
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
user_id = await create_user(email)
|
||||
else:
|
||||
user_id = user["id"]
|
||||
|
||||
|
||||
# Create magic link token
|
||||
token = secrets.token_urlsafe(32)
|
||||
await create_auth_token(user_id, token)
|
||||
|
||||
|
||||
# Queue email
|
||||
from ..worker import enqueue
|
||||
await enqueue("send_magic_link", {"email": email, "token": token})
|
||||
|
||||
|
||||
await flash("Check your email for the sign-in link!", "success")
|
||||
return redirect(url_for("auth.magic_link_sent", email=email))
|
||||
|
||||
|
||||
return await render_template("login.html")
|
||||
|
||||
|
||||
@@ -185,39 +217,39 @@ async def signup():
|
||||
"""Signup page - same as login but with different messaging."""
|
||||
if g.get("user"):
|
||||
return redirect(url_for("dashboard.index"))
|
||||
|
||||
|
||||
plan = request.args.get("plan", "free")
|
||||
|
||||
|
||||
if request.method == "POST":
|
||||
form = await request.form
|
||||
email = form.get("email", "").strip().lower()
|
||||
selected_plan = form.get("plan", "free")
|
||||
|
||||
|
||||
if not email or "@" not in email:
|
||||
await flash("Please enter a valid email address.", "error")
|
||||
return redirect(url_for("auth.signup", plan=selected_plan))
|
||||
|
||||
|
||||
# Check if user exists
|
||||
user = await get_user_by_email(email)
|
||||
if user:
|
||||
await flash("Account already exists. Please sign in.", "info")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
|
||||
# Create user
|
||||
user_id = await create_user(email)
|
||||
|
||||
|
||||
# Create magic link token
|
||||
token = secrets.token_urlsafe(32)
|
||||
await create_auth_token(user_id, token)
|
||||
|
||||
|
||||
# Queue emails
|
||||
from ..worker import enqueue
|
||||
await enqueue("send_magic_link", {"email": email, "token": token})
|
||||
await enqueue("send_welcome", {"email": email})
|
||||
|
||||
|
||||
await flash("Check your email to complete signup!", "success")
|
||||
return redirect(url_for("auth.magic_link_sent", email=email))
|
||||
|
||||
|
||||
return await render_template("signup.html", plan=plan)
|
||||
|
||||
|
||||
@@ -225,29 +257,32 @@ async def signup():
|
||||
async def verify():
|
||||
"""Verify magic link token."""
|
||||
token = request.args.get("token")
|
||||
|
||||
|
||||
if not token:
|
||||
await flash("Invalid or expired link.", "error")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
|
||||
token_data = await get_valid_token(token)
|
||||
|
||||
|
||||
if not token_data:
|
||||
await flash("Invalid or expired link. Please request a new one.", "error")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
|
||||
# Mark token as used
|
||||
await mark_token_used(token_data["id"])
|
||||
|
||||
|
||||
# Update last login
|
||||
await update_user(token_data["user_id"], last_login_at=datetime.utcnow().isoformat())
|
||||
|
||||
|
||||
# Set session
|
||||
session.permanent = True
|
||||
session["user_id"] = token_data["user_id"]
|
||||
|
||||
|
||||
# Auto-grant admin role if email is in ADMIN_EMAILS
|
||||
await ensure_admin_role(token_data["user_id"], token_data["email"])
|
||||
|
||||
await flash("Successfully signed in!", "success")
|
||||
|
||||
|
||||
# Redirect to intended page or dashboard
|
||||
next_url = request.args.get("next", url_for("dashboard.index"))
|
||||
return redirect(next_url)
|
||||
@@ -274,18 +309,21 @@ async def dev_login():
|
||||
"""Instant login for development. Only works in DEBUG mode."""
|
||||
if not config.DEBUG:
|
||||
return "Not available", 404
|
||||
|
||||
|
||||
email = request.args.get("email", "dev@localhost")
|
||||
|
||||
|
||||
user = await get_user_by_email(email)
|
||||
if not user:
|
||||
user_id = await create_user(email)
|
||||
else:
|
||||
user_id = user["id"]
|
||||
|
||||
|
||||
session.permanent = True
|
||||
session["user_id"] = user_id
|
||||
|
||||
|
||||
# Auto-grant admin role if email is in ADMIN_EMAILS
|
||||
await ensure_admin_role(user_id, email)
|
||||
|
||||
await flash(f"Dev login as {email}", "success")
|
||||
return redirect(url_for("dashboard.index"))
|
||||
|
||||
@@ -296,19 +334,19 @@ async def resend():
|
||||
"""Resend magic link."""
|
||||
form = await request.form
|
||||
email = form.get("email", "").strip().lower()
|
||||
|
||||
|
||||
if not email:
|
||||
await flash("Email address required.", "error")
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
|
||||
user = await get_user_by_email(email)
|
||||
if user:
|
||||
token = secrets.token_urlsafe(32)
|
||||
await create_auth_token(user["id"], token)
|
||||
|
||||
|
||||
from ..worker import enqueue
|
||||
await enqueue("send_magic_link", {"email": email, "token": token})
|
||||
|
||||
|
||||
# Always show success (don't reveal if email exists)
|
||||
await flash("If that email is registered, we've sent a new link.", "success")
|
||||
return redirect(url_for("auth.magic_link_sent", email=email))
|
||||
|
||||
@@ -2,16 +2,19 @@
|
||||
Core infrastructure: database, config, email, and shared utilities.
|
||||
"""
|
||||
import os
|
||||
import random
|
||||
import secrets
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
import aiosqlite
|
||||
import resend
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
from functools import wraps
|
||||
from datetime import datetime, timedelta
|
||||
from contextvars import ContextVar
|
||||
from quart import request, session, g
|
||||
from quart import g, make_response, request, session
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# web/.env is three levels up from web/src/beanflows/core.py
|
||||
@@ -26,27 +29,35 @@ class Config:
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production")
|
||||
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:5001")
|
||||
DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"
|
||||
|
||||
|
||||
DATABASE_PATH: str = os.getenv("DATABASE_PATH", "data/app.db")
|
||||
|
||||
|
||||
MAGIC_LINK_EXPIRY_MINUTES: int = int(os.getenv("MAGIC_LINK_EXPIRY_MINUTES", "15"))
|
||||
SESSION_LIFETIME_DAYS: int = int(os.getenv("SESSION_LIFETIME_DAYS", "30"))
|
||||
|
||||
|
||||
PAYMENT_PROVIDER: str = "paddle"
|
||||
|
||||
|
||||
PADDLE_API_KEY: str = os.getenv("PADDLE_API_KEY", "")
|
||||
PADDLE_WEBHOOK_SECRET: str = os.getenv("PADDLE_WEBHOOK_SECRET", "")
|
||||
PADDLE_ENVIRONMENT: str = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
|
||||
PADDLE_PRICES: dict = {
|
||||
"starter": os.getenv("PADDLE_PRICE_STARTER", ""),
|
||||
"pro": os.getenv("PADDLE_PRICE_PRO", ""),
|
||||
}
|
||||
|
||||
|
||||
UMAMI_SCRIPT_URL: str = os.getenv("UMAMI_SCRIPT_URL", "")
|
||||
UMAMI_WEBSITE_ID: str = os.getenv("UMAMI_WEBSITE_ID", "")
|
||||
|
||||
RESEND_API_KEY: str = os.getenv("RESEND_API_KEY", "")
|
||||
EMAIL_FROM: str = os.getenv("EMAIL_FROM", "hello@example.com")
|
||||
|
||||
|
||||
ADMIN_EMAILS: list[str] = [
|
||||
e.strip().lower() for e in os.getenv("ADMIN_EMAILS", "").split(",") if e.strip()
|
||||
]
|
||||
|
||||
RATE_LIMIT_REQUESTS: int = int(os.getenv("RATE_LIMIT_REQUESTS", "100"))
|
||||
RATE_LIMIT_WINDOW: int = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
|
||||
|
||||
|
||||
PLAN_FEATURES: dict = {
|
||||
"free": ["dashboard", "coffee_only", "limited_history"],
|
||||
"starter": ["dashboard", "coffee_only", "full_history", "export", "api"],
|
||||
@@ -74,10 +85,10 @@ async def init_db(path: str = None) -> None:
|
||||
global _db
|
||||
db_path = path or config.DATABASE_PATH
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
_db = await aiosqlite.connect(db_path)
|
||||
_db.row_factory = aiosqlite.Row
|
||||
|
||||
|
||||
await _db.execute("PRAGMA journal_mode=WAL")
|
||||
await _db.execute("PRAGMA foreign_keys=ON")
|
||||
await _db.execute("PRAGMA busy_timeout=5000")
|
||||
@@ -137,11 +148,11 @@ async def execute_many(sql: str, params_list: list[tuple]) -> None:
|
||||
|
||||
class transaction:
|
||||
"""Async context manager for transactions."""
|
||||
|
||||
|
||||
async def __aenter__(self):
|
||||
self.db = await get_db()
|
||||
return self.db
|
||||
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
await self.db.commit()
|
||||
@@ -153,25 +164,32 @@ class transaction:
|
||||
# Email
|
||||
# =============================================================================
|
||||
|
||||
async def send_email(to: str, subject: str, html: str, text: str = None) -> bool:
|
||||
"""Send email via Resend API."""
|
||||
EMAIL_ADDRESSES = {
|
||||
"transactional": f"{config.APP_NAME} <{config.EMAIL_FROM}>",
|
||||
}
|
||||
|
||||
|
||||
async def send_email(
|
||||
to: str, subject: str, html: str, text: str = None, from_addr: str = None
|
||||
) -> bool:
|
||||
"""Send email via Resend SDK."""
|
||||
if not config.RESEND_API_KEY:
|
||||
print(f"[EMAIL] Would send to {to}: {subject}")
|
||||
return True
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
"https://api.resend.com/emails",
|
||||
headers={"Authorization": f"Bearer {config.RESEND_API_KEY}"},
|
||||
json={
|
||||
"from": config.EMAIL_FROM,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html": html,
|
||||
"text": text or html,
|
||||
},
|
||||
)
|
||||
return response.status_code == 200
|
||||
|
||||
resend.api_key = config.RESEND_API_KEY
|
||||
try:
|
||||
resend.Emails.send({
|
||||
"from": from_addr or config.EMAIL_FROM,
|
||||
"to": to,
|
||||
"subject": subject,
|
||||
"html": html,
|
||||
"text": text or html,
|
||||
})
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[EMAIL] Error sending to {to}: {e}")
|
||||
return False
|
||||
|
||||
# =============================================================================
|
||||
# CSRF Protection
|
||||
@@ -214,34 +232,34 @@ async def check_rate_limit(key: str, limit: int = None, window: int = None) -> t
|
||||
window = window or config.RATE_LIMIT_WINDOW
|
||||
now = datetime.utcnow()
|
||||
window_start = now - timedelta(seconds=window)
|
||||
|
||||
|
||||
# Clean old entries and count recent
|
||||
await execute(
|
||||
"DELETE FROM rate_limits WHERE key = ? AND timestamp < ?",
|
||||
(key, window_start.isoformat())
|
||||
)
|
||||
|
||||
|
||||
result = await fetch_one(
|
||||
"SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?",
|
||||
(key, window_start.isoformat())
|
||||
)
|
||||
count = result["count"] if result else 0
|
||||
|
||||
|
||||
info = {
|
||||
"limit": limit,
|
||||
"remaining": max(0, limit - count - 1),
|
||||
"reset": int((window_start + timedelta(seconds=window)).timestamp()),
|
||||
}
|
||||
|
||||
|
||||
if count >= limit:
|
||||
return False, info
|
||||
|
||||
|
||||
# Record this request
|
||||
await execute(
|
||||
"INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)",
|
||||
(key, now.isoformat())
|
||||
)
|
||||
|
||||
|
||||
return True, info
|
||||
|
||||
|
||||
@@ -254,13 +272,13 @@ def rate_limit(limit: int = None, window: int = None, key_func=None):
|
||||
key = key_func()
|
||||
else:
|
||||
key = f"ip:{request.remote_addr}"
|
||||
|
||||
|
||||
allowed, info = await check_rate_limit(key, limit, window)
|
||||
|
||||
|
||||
if not allowed:
|
||||
response = {"error": "Rate limit exceeded", **info}
|
||||
return response, 429
|
||||
|
||||
|
||||
return await f(*args, **kwargs)
|
||||
return decorated
|
||||
return decorator
|
||||
@@ -284,7 +302,7 @@ def setup_request_id(app):
|
||||
rid = request.headers.get("X-Request-ID") or secrets.token_hex(8)
|
||||
request_id_var.set(rid)
|
||||
g.request_id = rid
|
||||
|
||||
|
||||
@app.after_request
|
||||
async def add_request_id_header(response):
|
||||
response.headers["X-Request-ID"] = get_request_id()
|
||||
@@ -294,13 +312,11 @@ def setup_request_id(app):
|
||||
# Webhook Signature Verification
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool:
|
||||
"""Verify HMAC-SHA256 webhook signature."""
|
||||
expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
|
||||
return hmac.compare_digest(signature, expected)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Soft Delete Helpers
|
||||
# =============================================================================
|
||||
@@ -336,3 +352,27 @@ async def purge_deleted(table: str, days: int = 30) -> int:
|
||||
f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?",
|
||||
(cutoff,)
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# A/B Testing
|
||||
# =============================================================================
|
||||
|
||||
def ab_test(experiment: str, variants: tuple = ("control", "treatment")):
|
||||
"""Assign visitor to an A/B test variant via cookie, tag Umami pageviews."""
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
async def wrapper(*args, **kwargs):
|
||||
cookie_key = f"ab_{experiment}"
|
||||
assigned = request.cookies.get(cookie_key)
|
||||
if assigned not in variants:
|
||||
assigned = random.choice(variants)
|
||||
|
||||
g.ab_variant = assigned
|
||||
g.ab_tag = f"{experiment}-{assigned}"
|
||||
|
||||
response = await make_response(await f(*args, **kwargs))
|
||||
response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60)
|
||||
return response
|
||||
return wrapper
|
||||
return decorator
|
||||
@@ -1,51 +1,95 @@
|
||||
"""
|
||||
Simple migration runner. Runs schema.sql against the database.
|
||||
"""
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
import os
|
||||
import sys
|
||||
Sequential migration runner.
|
||||
|
||||
Replays all migrations in order. All databases — fresh and existing —
|
||||
go through the same path. No schema.sql fast-path.
|
||||
|
||||
- Scans versions/ for NNNN_*.py files and runs unapplied ones in order
|
||||
- Each migration has an up(conn) function receiving an uncommitted connection
|
||||
- All pending migrations share a single transaction (batch atomicity)
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent to path for imports
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
VERSIONS_DIR = Path(__file__).parent / "versions"
|
||||
VERSION_RE = re.compile(r"^(\d{4})_.+\.py$")
|
||||
|
||||
def migrate():
|
||||
"""Run migrations."""
|
||||
# Get database path from env or default
|
||||
db_path = os.getenv("DATABASE_PATH", "data/app.db")
|
||||
|
||||
# Ensure directory exists
|
||||
# Derived from the package path: …/src/<slug>/migrations/migrate.py
|
||||
_PACKAGE = Path(__file__).parent.parent.name # e.g. "myproject"
|
||||
|
||||
|
||||
def _discover_versions():
|
||||
"""Return sorted list of version file stems."""
|
||||
if not VERSIONS_DIR.is_dir():
|
||||
return []
|
||||
versions = []
|
||||
for f in sorted(VERSIONS_DIR.iterdir()):
|
||||
if VERSION_RE.match(f.name):
|
||||
versions.append(f.stem)
|
||||
return versions
|
||||
|
||||
|
||||
def migrate(db_path=None):
|
||||
if db_path is None:
|
||||
db_path = os.getenv("DATABASE_PATH", "data/app.db")
|
||||
Path(db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Read schema
|
||||
schema_path = Path(__file__).parent / "schema.sql"
|
||||
schema = schema_path.read_text()
|
||||
|
||||
# Connect and execute
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
# Enable WAL mode
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute("PRAGMA foreign_keys=ON")
|
||||
|
||||
# Run schema
|
||||
conn.executescript(schema)
|
||||
|
||||
# Ensure tracking table exists before anything else
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _migrations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
|
||||
print(f"✓ Migrations complete: {db_path}")
|
||||
|
||||
# Show tables
|
||||
|
||||
versions = _discover_versions()
|
||||
applied = {
|
||||
row[0]
|
||||
for row in conn.execute("SELECT name FROM _migrations").fetchall()
|
||||
}
|
||||
pending = [v for v in versions if v not in applied]
|
||||
|
||||
if pending:
|
||||
for name in pending:
|
||||
print(f" Applying {name}...")
|
||||
mod = importlib.import_module(
|
||||
f"{_PACKAGE}.migrations.versions.{name}"
|
||||
)
|
||||
mod.up(conn)
|
||||
conn.execute(
|
||||
"INSERT INTO _migrations (name) VALUES (?)", (name,)
|
||||
)
|
||||
conn.commit()
|
||||
print(f"✓ Applied {len(pending)} migration(s): {db_path}")
|
||||
else:
|
||||
print(f"✓ All migrations already applied: {db_path}")
|
||||
|
||||
# Show tables (excluding internal sqlite/fts tables)
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
|
||||
"SELECT name FROM sqlite_master WHERE type='table'"
|
||||
" AND name NOT LIKE 'sqlite_%'"
|
||||
" ORDER BY name"
|
||||
)
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
print(f" Tables: {', '.join(tables)}")
|
||||
|
||||
|
||||
conn.close()
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
-- BeanFlows Database Schema
|
||||
-- Run with: python -m beanflows.migrations.migrate
|
||||
|
||||
-- Migration tracking
|
||||
CREATE TABLE IF NOT EXISTS _migrations (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT UNIQUE NOT NULL,
|
||||
applied_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
-- Users
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
@@ -28,25 +35,58 @@ CREATE TABLE IF NOT EXISTS auth_tokens (
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_tokens_token ON auth_tokens(token);
|
||||
CREATE INDEX IF NOT EXISTS idx_auth_tokens_user ON auth_tokens(user_id);
|
||||
|
||||
-- User Roles
|
||||
CREATE TABLE IF NOT EXISTS user_roles (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
role TEXT NOT NULL,
|
||||
granted_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
UNIQUE(user_id, role)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_roles_user ON user_roles(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_user_roles_role ON user_roles(role);
|
||||
|
||||
-- Billing Customers (payment provider identity, separate from subscriptions)
|
||||
CREATE TABLE IF NOT EXISTS billing_customers (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
|
||||
provider_customer_id TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_customers_user ON billing_customers(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_customers_provider ON billing_customers(provider_customer_id);
|
||||
|
||||
-- Subscriptions
|
||||
CREATE TABLE IF NOT EXISTS subscriptions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL UNIQUE REFERENCES users(id),
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
plan TEXT NOT NULL DEFAULT 'free',
|
||||
status TEXT NOT NULL DEFAULT 'free',
|
||||
|
||||
paddle_customer_id TEXT,
|
||||
paddle_subscription_id TEXT,
|
||||
|
||||
provider_subscription_id TEXT,
|
||||
current_period_end TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_user ON subscriptions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(provider_subscription_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_subscriptions_provider ON subscriptions(paddle_subscription_id);
|
||||
-- Transactions
|
||||
CREATE TABLE IF NOT EXISTS transactions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
subscription_id INTEGER REFERENCES subscriptions(id),
|
||||
provider_transaction_id TEXT UNIQUE,
|
||||
type TEXT NOT NULL DEFAULT 'payment',
|
||||
amount_cents INTEGER,
|
||||
currency TEXT DEFAULT 'USD',
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
metadata TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_transactions_user ON transactions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_transactions_provider ON transactions(provider_transaction_id);
|
||||
|
||||
-- API Keys
|
||||
CREATE TABLE IF NOT EXISTS api_keys (
|
||||
@@ -99,3 +139,39 @@ CREATE TABLE IF NOT EXISTS tasks (
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status, run_at);
|
||||
|
||||
-- Items (example domain entity - replace with your domain)
|
||||
CREATE TABLE IF NOT EXISTS items (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL REFERENCES users(id),
|
||||
name TEXT NOT NULL,
|
||||
data TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT,
|
||||
deleted_at TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_items_user ON items(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_items_deleted ON items(deleted_at);
|
||||
|
||||
-- Full-text search for items (optional)
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS items_fts USING fts5(
|
||||
name,
|
||||
data,
|
||||
content='items',
|
||||
content_rowid='id'
|
||||
);
|
||||
|
||||
-- FTS triggers
|
||||
CREATE TRIGGER IF NOT EXISTS items_ai AFTER INSERT ON items BEGIN
|
||||
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS items_ad AFTER DELETE ON items BEGIN
|
||||
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
|
||||
END;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS items_au AFTER UPDATE ON items BEGIN
|
||||
INSERT INTO items_fts(items_fts, rowid, name, data) VALUES('delete', old.id, old.name, old.data);
|
||||
INSERT INTO items_fts(rowid, name, data) VALUES (new.id, new.name, new.data);
|
||||
END;
|
||||
0
web/src/beanflows/scripts/__init__.py
Normal file
0
web/src/beanflows/scripts/__init__.py
Normal file
92
web/src/beanflows/scripts/setup_paddle.py
Normal file
92
web/src/beanflows/scripts/setup_paddle.py
Normal file
@@ -0,0 +1,92 @@
|
||||
|
||||
"""
|
||||
Create Paddle products and prices for BeanFlows.
|
||||
|
||||
Run once per environment (sandbox, then production).
|
||||
Prints resulting price IDs as a .env snippet.
|
||||
|
||||
Usage:
|
||||
uv run python -m beanflows.scripts.setup_paddle
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from paddle_billing import Client as PaddleClient
|
||||
from paddle_billing import Environment, Options
|
||||
from paddle_billing.Entities.Shared import CurrencyCode, Money, TaxCategory
|
||||
from paddle_billing.Resources.Prices.Operations import CreatePrice
|
||||
from paddle_billing.Resources.Products.Operations import CreateProduct
|
||||
|
||||
load_dotenv()
|
||||
|
||||
PADDLE_API_KEY = os.getenv("PADDLE_API_KEY", "")
|
||||
PADDLE_ENVIRONMENT = os.getenv("PADDLE_ENVIRONMENT", "sandbox")
|
||||
|
||||
if not PADDLE_API_KEY:
|
||||
print("ERROR: Set PADDLE_API_KEY in .env first")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
PRODUCTS = [
|
||||
# Subscriptions
|
||||
{
|
||||
"name": "Starter",
|
||||
"env_key": "PADDLE_PRICE_STARTER",
|
||||
"price": 900,
|
||||
"currency": CurrencyCode.USD,
|
||||
"interval": "month",
|
||||
"type": "subscription",
|
||||
},
|
||||
{
|
||||
"name": "Pro",
|
||||
"env_key": "PADDLE_PRICE_PRO",
|
||||
"price": 2900,
|
||||
"currency": CurrencyCode.USD,
|
||||
"interval": "month",
|
||||
"type": "subscription",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def main():
|
||||
env = Environment.SANDBOX if PADDLE_ENVIRONMENT == "sandbox" else Environment.PRODUCTION
|
||||
paddle = PaddleClient(PADDLE_API_KEY, options=Options(env))
|
||||
|
||||
print(f"Creating products in {PADDLE_ENVIRONMENT}...\n")
|
||||
|
||||
env_lines = []
|
||||
|
||||
for spec in PRODUCTS:
|
||||
# Create product
|
||||
product = paddle.products.create(CreateProduct(
|
||||
name=spec["name"],
|
||||
tax_category=TaxCategory.Standard,
|
||||
))
|
||||
print(f" Product: {spec['name']} -> {product.id}")
|
||||
|
||||
# Create price
|
||||
price_kwargs = {
|
||||
"description": spec["name"],
|
||||
"product_id": product.id,
|
||||
"unit_price": Money(str(spec["price"]), spec["currency"]),
|
||||
}
|
||||
|
||||
if spec["type"] == "subscription":
|
||||
from paddle_billing.Entities.Shared import TimePeriod
|
||||
price_kwargs["billing_cycle"] = TimePeriod(interval="month", frequency=1)
|
||||
|
||||
price = paddle.prices.create(CreatePrice(**price_kwargs))
|
||||
print(f" Price: {spec['env_key']} = {price.id}")
|
||||
|
||||
env_lines.append(f"{spec['env_key']}={price.id}")
|
||||
|
||||
print("\n# --- .env snippet ---")
|
||||
for line in env_lines:
|
||||
print(line)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
BIN
web/src/beanflows/static/fonts/CommitMono-400-Regular.woff2
Normal file
BIN
web/src/beanflows/static/fonts/CommitMono-400-Regular.woff2
Normal file
Binary file not shown.
BIN
web/src/beanflows/static/fonts/CommitMono-700-Regular.woff2
Normal file
BIN
web/src/beanflows/static/fonts/CommitMono-700-Regular.woff2
Normal file
Binary file not shown.
90
web/src/beanflows/static/fonts/CommitMono-LICENSE.txt
Normal file
90
web/src/beanflows/static/fonts/CommitMono-LICENSE.txt
Normal file
@@ -0,0 +1,90 @@
|
||||
This Font Software is licensed under the SIL Open Font License, Version 1.1.
|
||||
This license is copied below, and is also available with a FAQ at:
|
||||
http://scripts.sil.org/OFL
|
||||
|
||||
-----------------------------------------------------------
|
||||
SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007
|
||||
-----------------------------------------------------------
|
||||
|
||||
PREAMBLE
|
||||
The goals of the Open Font License (OFL) are to stimulate worldwide
|
||||
development of collaborative font projects, to support the font creation
|
||||
efforts of academic and linguistic communities, and to provide a free and
|
||||
open framework in which fonts may be shared and improved in partnership
|
||||
with others.
|
||||
|
||||
The OFL allows the licensed fonts to be used, studied, modified and
|
||||
redistributed freely as long as they are not sold by themselves. The
|
||||
fonts, including any derivative works, can be bundled, embedded,
|
||||
redistributed and/or sold with any software provided that any reserved
|
||||
names are not used by derivative works. The fonts and derivatives,
|
||||
however, cannot be released under any other type of license. The
|
||||
requirement for fonts to remain under this license does not apply
|
||||
to any document created using the fonts or their derivatives.
|
||||
|
||||
DEFINITIONS
|
||||
"Font Software" refers to the set of files released by the Copyright
|
||||
Holder(s) under this license and clearly marked as such. This may
|
||||
include source files, build scripts and documentation.
|
||||
|
||||
"Reserved Font Name" refers to any names specified as such after the
|
||||
copyright statement(s).
|
||||
|
||||
"Original Version" refers to the collection of Font Software components as
|
||||
distributed by the Copyright Holder(s).
|
||||
|
||||
"Modified Version" refers to any derivative made by adding to, deleting,
|
||||
or substituting -- in part or in whole -- any of the components of the
|
||||
Original Version, by changing formats or by porting the Font Software to a
|
||||
new environment.
|
||||
|
||||
"Author" refers to any designer, engineer, programmer, technical
|
||||
writer or other person who contributed to the Font Software.
|
||||
|
||||
PERMISSION & CONDITIONS
|
||||
Permission is hereby granted, free of charge, to any person obtaining
|
||||
a copy of the Font Software, to use, study, copy, merge, embed, modify,
|
||||
redistribute, and sell modified and unmodified copies of the Font
|
||||
Software, subject to the following conditions:
|
||||
|
||||
1) Neither the Font Software nor any of its individual components,
|
||||
in Original or Modified Versions, may be sold by itself.
|
||||
|
||||
2) Original or Modified Versions of the Font Software may be bundled,
|
||||
redistributed and/or sold with any software, provided that each copy
|
||||
contains the above copyright notice and this license. These can be
|
||||
included either as stand-alone text files, human-readable headers or
|
||||
in the appropriate machine-readable metadata fields within text or
|
||||
binary files as long as those fields can be easily viewed by the user.
|
||||
|
||||
3) No Modified Version of the Font Software may use the Reserved Font
|
||||
Name(s) unless explicit written permission is granted by the corresponding
|
||||
Copyright Holder. This restriction only applies to the primary font name as
|
||||
presented to the users.
|
||||
|
||||
4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font
|
||||
Software shall not be used to promote, endorse or advertise any
|
||||
Modified Version, except to acknowledge the contribution(s) of the
|
||||
Copyright Holder(s) and the Author(s) or with their explicit written
|
||||
permission.
|
||||
|
||||
5) The Font Software, modified or unmodified, in part or in whole,
|
||||
must be distributed entirely under this license, and must not be
|
||||
distributed under any other license. The requirement for fonts to
|
||||
remain under this license does not apply to any document created
|
||||
using the Font Software.
|
||||
|
||||
TERMINATION
|
||||
This license becomes null and void if any of the above conditions are
|
||||
not met.
|
||||
|
||||
DISCLAIMER
|
||||
THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT
|
||||
OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE
|
||||
COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL
|
||||
DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM
|
||||
OTHER DEALINGS IN THE FONT SOFTWARE.
|
||||
@@ -13,6 +13,46 @@ from .core import config, init_db, fetch_one, fetch_all, execute, send_email
|
||||
HANDLERS: dict[str, callable] = {}
|
||||
|
||||
|
||||
def _email_wrap(body: str) -> str:
|
||||
"""Wrap email body in a branded layout with inline CSS."""
|
||||
return f"""\
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head><meta charset="utf-8"></head>
|
||||
<body style="margin:0;padding:0;background-color:#F8FAFC;font-family:'Inter',Helvetica,Arial,sans-serif;">
|
||||
<table width="100%" cellpadding="0" cellspacing="0" style="background-color:#F8FAFC;padding:40px 0;">
|
||||
<tr><td align="center">
|
||||
<table width="480" cellpadding="0" cellspacing="0" style="background-color:#FFFFFF;border-radius:8px;border:1px solid #E2E8F0;overflow:hidden;">
|
||||
<!-- Header -->
|
||||
<tr><td style="background-color:#0F172A;padding:24px 32px;">
|
||||
<span style="color:#FFFFFF;font-size:18px;font-weight:700;letter-spacing:-0.02em;">{config.APP_NAME}</span>
|
||||
</td></tr>
|
||||
<!-- Body -->
|
||||
<tr><td style="padding:32px;color:#475569;font-size:15px;line-height:1.6;">
|
||||
{body}
|
||||
</td></tr>
|
||||
<!-- Footer -->
|
||||
<tr><td style="padding:20px 32px;border-top:1px solid #E2E8F0;text-align:center;">
|
||||
<span style="color:#94A3B8;font-size:12px;">© {config.APP_NAME} · You received this because you have an account.</span>
|
||||
</td></tr>
|
||||
</table>
|
||||
</td></tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>"""
|
||||
|
||||
|
||||
def _email_button(url: str, label: str) -> str:
|
||||
"""Render a branded CTA button for email."""
|
||||
return (
|
||||
f'<table cellpadding="0" cellspacing="0" style="margin:24px 0;">'
|
||||
f'<tr><td style="background-color:#3B82F6;border-radius:6px;text-align:center;">'
|
||||
f'<a href="{url}" style="display:inline-block;padding:12px 28px;'
|
||||
f'color:#FFFFFF;font-size:15px;font-weight:600;text-decoration:none;">'
|
||||
f'{label}</a></td></tr></table>'
|
||||
)
|
||||
|
||||
|
||||
def task(name: str):
|
||||
"""Decorator to register a task handler."""
|
||||
def decorator(f):
|
||||
@@ -46,7 +86,7 @@ async def get_pending_tasks(limit: int = 10) -> list[dict]:
|
||||
now = datetime.utcnow().isoformat()
|
||||
return await fetch_all(
|
||||
"""
|
||||
SELECT * FROM tasks
|
||||
SELECT * FROM tasks
|
||||
WHERE status = 'pending' AND run_at <= ?
|
||||
ORDER BY run_at ASC
|
||||
LIMIT ?
|
||||
@@ -66,15 +106,15 @@ async def mark_complete(task_id: int) -> None:
|
||||
async def mark_failed(task_id: int, error: str, retries: int) -> None:
|
||||
"""Mark task as failed, schedule retry if attempts remain."""
|
||||
max_retries = 3
|
||||
|
||||
|
||||
if retries < max_retries:
|
||||
# Exponential backoff: 1min, 5min, 25min
|
||||
delay = timedelta(minutes=5 ** retries)
|
||||
run_at = datetime.utcnow() + delay
|
||||
|
||||
|
||||
await execute(
|
||||
"""
|
||||
UPDATE tasks
|
||||
UPDATE tasks
|
||||
SET status = 'pending', error = ?, retries = ?, run_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
@@ -99,6 +139,7 @@ async def handle_send_email(payload: dict) -> None:
|
||||
subject=payload["subject"],
|
||||
html=payload["html"],
|
||||
text=payload.get("text"),
|
||||
from_addr=payload.get("from_addr"),
|
||||
)
|
||||
|
||||
|
||||
@@ -106,35 +147,37 @@ async def handle_send_email(payload: dict) -> None:
|
||||
async def handle_send_magic_link(payload: dict) -> None:
|
||||
"""Send magic link email."""
|
||||
link = f"{config.BASE_URL}/auth/verify?token={payload['token']}"
|
||||
|
||||
html = f"""
|
||||
<h2>Sign in to {config.APP_NAME}</h2>
|
||||
<p>Click the link below to sign in:</p>
|
||||
<p><a href="{link}">{link}</a></p>
|
||||
<p>This link expires in {config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>
|
||||
<p>If you didn't request this, you can safely ignore this email.</p>
|
||||
"""
|
||||
|
||||
|
||||
body = (
|
||||
f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Sign in to {config.APP_NAME}</h2>'
|
||||
f"<p>Click the button below to sign in. This link expires in "
|
||||
f"{config.MAGIC_LINK_EXPIRY_MINUTES} minutes.</p>"
|
||||
f"{_email_button(link, 'Sign In')}"
|
||||
f'<p style="font-size:13px;color:#94A3B8;">If the button doesn\'t work, copy and paste this URL into your browser:</p>'
|
||||
f'<p style="font-size:13px;color:#94A3B8;word-break:break-all;">{link}</p>'
|
||||
f'<p style="font-size:13px;color:#94A3B8;">If you didn\'t request this, you can safely ignore this email.</p>'
|
||||
)
|
||||
|
||||
await send_email(
|
||||
to=payload["email"],
|
||||
subject=f"Sign in to {config.APP_NAME}",
|
||||
html=html,
|
||||
html=_email_wrap(body),
|
||||
)
|
||||
|
||||
|
||||
@task("send_welcome")
|
||||
async def handle_send_welcome(payload: dict) -> None:
|
||||
"""Send welcome email to new user."""
|
||||
html = f"""
|
||||
<h2>Welcome to {config.APP_NAME}!</h2>
|
||||
<p>Thanks for signing up. We're excited to have you.</p>
|
||||
<p><a href="{config.BASE_URL}/dashboard">Go to your dashboard</a></p>
|
||||
"""
|
||||
|
||||
body = (
|
||||
f'<h2 style="margin:0 0 16px;color:#0F172A;font-size:20px;">Welcome to {config.APP_NAME}!</h2>'
|
||||
f"<p>Thanks for signing up. We're excited to have you.</p>"
|
||||
f'{_email_button(f"{config.BASE_URL}/dashboard", "Go to Dashboard")}'
|
||||
)
|
||||
|
||||
await send_email(
|
||||
to=payload["email"],
|
||||
subject=f"Welcome to {config.APP_NAME}",
|
||||
html=html,
|
||||
html=_email_wrap(body),
|
||||
)
|
||||
|
||||
|
||||
@@ -173,12 +216,12 @@ async def process_task(task: dict) -> None:
|
||||
task_name = task["task_name"]
|
||||
task_id = task["id"]
|
||||
retries = task.get("retries", 0)
|
||||
|
||||
|
||||
handler = HANDLERS.get(task_name)
|
||||
if not handler:
|
||||
await mark_failed(task_id, f"Unknown task: {task_name}", retries)
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
payload = json.loads(task["payload"]) if task["payload"] else {}
|
||||
await handler(payload)
|
||||
@@ -194,17 +237,17 @@ async def run_worker(poll_interval: float = 1.0) -> None:
|
||||
"""Main worker loop."""
|
||||
print("[WORKER] Starting...")
|
||||
await init_db()
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
tasks = await get_pending_tasks(limit=10)
|
||||
|
||||
|
||||
for task in tasks:
|
||||
await process_task(task)
|
||||
|
||||
|
||||
if not tasks:
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WORKER] Error: {e}")
|
||||
await asyncio.sleep(poll_interval * 5)
|
||||
@@ -214,16 +257,16 @@ async def run_scheduler() -> None:
|
||||
"""Schedule periodic cleanup tasks."""
|
||||
print("[SCHEDULER] Starting...")
|
||||
await init_db()
|
||||
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Schedule cleanup tasks every hour
|
||||
await enqueue("cleanup_expired_tokens")
|
||||
await enqueue("cleanup_rate_limits")
|
||||
await enqueue("cleanup_old_tasks")
|
||||
|
||||
|
||||
await asyncio.sleep(3600) # 1 hour
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[SCHEDULER] Error: {e}")
|
||||
await asyncio.sleep(60)
|
||||
@@ -231,8 +274,8 @@ async def run_scheduler() -> None:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "scheduler":
|
||||
asyncio.run(run_scheduler())
|
||||
else:
|
||||
asyncio.run(run_worker())
|
||||
asyncio.run(run_worker())
|
||||
Reference in New Issue
Block a user