diff --git a/web/src/beanflows/core.py b/web/src/beanflows/core.py index e9b5c0d..c404a88 100644 --- a/web/src/beanflows/core.py +++ b/web/src/beanflows/core.py @@ -1,29 +1,33 @@ """ Core infrastructure: database, config, email, and shared utilities. """ + +import hashlib +import hmac import os import random import secrets -import hashlib -import hmac +from contextvars import ContextVar +from datetime import datetime, timedelta +from functools import wraps +from pathlib import Path import aiosqlite import resend -import httpx -from pathlib import Path -from functools import wraps -from datetime import datetime, timedelta -from contextvars import ContextVar -from quart import g, make_response, render_template, request, session from dotenv import load_dotenv +from quart import g, make_response, render_template, request, session -# web/.env is three levels up from web/src/beanflows/core.py -load_dotenv(Path(__file__).parent.parent.parent / ".env", override=False) +# Load .env from repo root first (created by `make secrets-decrypt-dev`), +# fall back to web/.env for legacy local dev setups. +_repo_root = Path(__file__).parent.parent.parent.parent +load_dotenv(_repo_root / ".env", override=False) +load_dotenv(_repo_root / "web" / ".env", override=False) # ============================================================================= # Configuration # ============================================================================= + class Config: APP_NAME: str = os.getenv("APP_NAME", "BeanFlows") SECRET_KEY: str = os.getenv("SECRET_KEY", "change-me-in-production") @@ -53,7 +57,9 @@ class Config: ADMIN_EMAILS: list[str] = [ e.strip().lower() - for e in os.getenv("ADMIN_EMAILS", "hendrik@beanflow.coffee,simon@beanflows.coffee").split(",") + for e in os.getenv("ADMIN_EMAILS", "hendrik@beanflow.coffee,simon@beanflows.coffee").split( + "," + ) if e.strip() ] @@ -66,7 +72,14 @@ class Config: PLAN_FEATURES: dict = { "free": ["dashboard", "coffee_only", "limited_history"], "starter": ["dashboard", "coffee_only", "full_history", "export", "api"], - "pro": ["dashboard", "all_commodities", "full_history", "export", "api", "priority_support"], + "pro": [ + "dashboard", + "all_commodities", + "full_history", + "export", + "api", + "priority_support", + ], } PLAN_LIMITS: dict = { @@ -165,6 +178,7 @@ class transaction: await self.db.rollback() return False + # ============================================================================= # Email # ============================================================================= @@ -175,8 +189,12 @@ EMAIL_ADDRESSES = { async def send_email( - to: str, subject: str, html: str, text: str = None, - from_addr: str = None, template: str = None, + to: str, + subject: str, + html: str, + text: str = None, + from_addr: str = None, + template: str = None, ) -> bool: """Send email via Resend SDK and log to email_log table.""" if not config.RESEND_API_KEY: @@ -191,13 +209,15 @@ async def send_email( provider_id = None error_msg = None try: - result = resend.Emails.send({ - "from": from_addr or config.EMAIL_FROM, - "to": to, - "subject": subject, - "html": html, - "text": text or html, - }) + result = resend.Emails.send( + { + "from": from_addr or config.EMAIL_FROM, + "to": to, + "subject": subject, + "html": html, + "text": text or html, + } + ) provider_id = result.get("id") if isinstance(result, dict) else None except Exception as e: error_msg = str(e) @@ -206,15 +226,24 @@ async def send_email( await execute( """INSERT INTO email_log (recipient, subject, template, status, provider_id, error, created_at) VALUES (?, ?, ?, ?, ?, ?, ?)""", - (to, subject, template, "error" if error_msg else "sent", - provider_id, error_msg, datetime.utcnow().isoformat()), + ( + to, + subject, + template, + "error" if error_msg else "sent", + provider_id, + error_msg, + datetime.utcnow().isoformat(), + ), ) return error_msg is None + # ============================================================================= # CSRF Protection # ============================================================================= + def get_csrf_token() -> str: """Get or create CSRF token for current session.""" if "csrf_token" not in session: @@ -229,6 +258,7 @@ def validate_csrf_token(token: str) -> bool: def csrf_protect(f): """Decorator to require valid CSRF token for POST requests.""" + @wraps(f) async def decorated(*args, **kwargs): if request.method == "POST": @@ -237,12 +267,15 @@ def csrf_protect(f): if not validate_csrf_token(token): return {"error": "Invalid CSRF token"}, 403 return await f(*args, **kwargs) + return decorated + # ============================================================================= # Rate Limiting (SQLite-based) # ============================================================================= + async def check_rate_limit(key: str, limit: int = None, window: int = None) -> tuple[bool, dict]: """ Check if rate limit exceeded. Returns (is_allowed, info). @@ -255,13 +288,12 @@ async def check_rate_limit(key: str, limit: int = None, window: int = None) -> t # Clean old entries and count recent await execute( - "DELETE FROM rate_limits WHERE key = ? AND timestamp < ?", - (key, window_start.isoformat()) + "DELETE FROM rate_limits WHERE key = ? AND timestamp < ?", (key, window_start.isoformat()) ) result = await fetch_one( "SELECT COUNT(*) as count FROM rate_limits WHERE key = ? AND timestamp > ?", - (key, window_start.isoformat()) + (key, window_start.isoformat()), ) count = result["count"] if result else 0 @@ -275,16 +307,14 @@ async def check_rate_limit(key: str, limit: int = None, window: int = None) -> t return False, info # Record this request - await execute( - "INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)", - (key, now.isoformat()) - ) + await execute("INSERT INTO rate_limits (key, timestamp) VALUES (?, ?)", (key, now.isoformat())) return True, info def rate_limit(limit: int = None, window: int = None, key_func=None): """Decorator for rate limiting routes.""" + def decorator(f): @wraps(f) async def decorated(*args, **kwargs): @@ -300,9 +330,12 @@ def rate_limit(limit: int = None, window: int = None, key_func=None): return response, 429 return await f(*args, **kwargs) + return decorated + return decorator + # ============================================================================= # Request ID Tracking # ============================================================================= @@ -317,6 +350,7 @@ def get_request_id() -> str: def setup_request_id(app): """Setup request ID middleware.""" + @app.before_request async def set_request_id(): rid = request.headers.get("X-Request-ID") or secrets.token_hex(8) @@ -328,34 +362,35 @@ def setup_request_id(app): response.headers["X-Request-ID"] = get_request_id() return response + # ============================================================================= # Webhook Signature Verification # ============================================================================= + def verify_hmac_signature(payload: bytes, signature: str, secret: str) -> bool: """Verify HMAC-SHA256 webhook signature.""" expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest() return hmac.compare_digest(signature, expected) + # ============================================================================= # Soft Delete Helpers # ============================================================================= + async def soft_delete(table: str, id: int) -> bool: """Mark record as deleted.""" result = await execute( f"UPDATE {table} SET deleted_at = ? WHERE id = ? AND deleted_at IS NULL", - (datetime.utcnow().isoformat(), id) + (datetime.utcnow().isoformat(), id), ) return result > 0 async def restore(table: str, id: int) -> bool: """Restore soft-deleted record.""" - result = await execute( - f"UPDATE {table} SET deleted_at = NULL WHERE id = ?", - (id,) - ) + result = await execute(f"UPDATE {table} SET deleted_at = NULL WHERE id = ?", (id,)) return result > 0 @@ -369,8 +404,7 @@ async def purge_deleted(table: str, days: int = 30) -> int: """Purge records deleted more than X days ago.""" cutoff = (datetime.utcnow() - timedelta(days=days)).isoformat() return await execute( - f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?", - (cutoff,) + f"DELETE FROM {table} WHERE deleted_at IS NOT NULL AND deleted_at < ?", (cutoff,) ) @@ -378,8 +412,10 @@ async def purge_deleted(table: str, days: int = 30) -> int: # Waitlist # ============================================================================= + def waitlist_gate(template: str, **extra_context): """Intercept GET requests when WAITLIST_MODE=true and render the waitlist template.""" + def decorator(f): @wraps(f) async def wrapper(*args, **kwargs): @@ -389,7 +425,9 @@ def waitlist_gate(template: str, **extra_context): ctx[k] = v() if callable(v) else v return await render_template(template, **ctx) return await f(*args, **kwargs) + return wrapper + return decorator @@ -411,16 +449,19 @@ async def capture_waitlist_email( if result: from .worker import enqueue + await enqueue("send_waitlist_confirmation", {"email": email}) if config.RESEND_AUDIENCE_WAITLIST and config.RESEND_API_KEY: try: resend.api_key = config.RESEND_API_KEY - resend.Contacts.create({ - "email": email, - "audience_id": config.RESEND_AUDIENCE_WAITLIST, - "unsubscribed": False, - }) + resend.Contacts.create( + { + "email": email, + "audience_id": config.RESEND_AUDIENCE_WAITLIST, + "unsubscribed": False, + } + ) except Exception as e: print(f"[WAITLIST] Resend audience error: {e}") @@ -432,8 +473,10 @@ async def capture_waitlist_email( # A/B Testing # ============================================================================= + def ab_test(experiment: str, variants: tuple = ("control", "treatment")): """Assign visitor to an A/B test variant via cookie, tag Umami pageviews.""" + def decorator(f): @wraps(f) async def wrapper(*args, **kwargs): @@ -448,7 +491,9 @@ def ab_test(experiment: str, variants: tuple = ("control", "treatment")): response = await make_response(await f(*args, **kwargs)) response.set_cookie(cookie_key, assigned, max_age=30 * 24 * 60 * 60) return response + return wrapper + return decorator @@ -456,6 +501,7 @@ def ab_test(experiment: str, variants: tuple = ("control", "treatment")): # Feature Flags (DB-backed, admin-toggleable) # ============================================================================= + async def is_flag_enabled(name: str, default: bool = False) -> bool: """Check if a feature flag is enabled. Falls back to default if not found.""" row = await fetch_one("SELECT enabled FROM feature_flags WHERE name = ?", (name,)) @@ -482,6 +528,7 @@ async def get_all_flags() -> list[dict]: def feature_gate(flag_name: str, fallback_template: str, **extra_context): """Gate a route behind a feature flag; renders fallback on GET, 403 on POST.""" + def decorator(f): @wraps(f) async def decorated(*args, **kwargs): @@ -491,5 +538,7 @@ def feature_gate(flag_name: str, fallback_template: str, **extra_context): return await render_template(fallback_template, **ctx) return {"error": "Feature not available"}, 403 return await f(*args, **kwargs) + return decorated - return decorator \ No newline at end of file + + return decorator