from __future__ import annotations import base64 import hashlib import os import re import sqlite3 from contextlib import contextmanager from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any, Dict, List, Optional from uuid import uuid4 from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel, Field DATABASE_PATH = Path(os.environ.get("ARBEIDSPULS_SHARE_DB", "backend/share-data.sqlite3")) EXPIRY_DAYS = int(os.environ.get("ARBEIDSPULS_SHARE_EXPIRY_DAYS", "7")) MAX_PAYLOAD_BYTES = 2 * 1024 * 1024 MAX_CIPHERTEXT_CHARS = 3_000_000 CREATE_RATE_LIMIT_PER_HOUR = 10 READ_RATE_LIMIT_PER_HOUR = 120 CONFIRM_RATE_LIMIT_PER_HOUR = 60 TRUSTED_PROXY_HOSTS = {"127.0.0.1", "::1", "localhost"} ALLOWED_ORIGINS = [ "https://arbeidspuls.rolfsvaag.no", "http://localhost:5173", "http://127.0.0.1:5173", ] app = FastAPI(title="Arbeidspuls secure share API") app.add_middleware( CORSMiddleware, allow_origins=[origin.strip() for origin in os.environ.get("ARBEIDSPULS_CORS_ORIGINS", ",".join(ALLOWED_ORIGINS)).split(",")], allow_credentials=False, allow_methods=["POST", "GET", "OPTIONS"], allow_headers=["Content-Type"], ) rate_bucket: Dict[str, List[datetime]] = {} class ShareCreateRequest(BaseModel): ciphertext: str = Field(min_length=1, max_length=MAX_CIPHERTEXT_CHARS) iv: str = Field(min_length=1, max_length=64) share_schema: str = Field(alias="schema", min_length=1, max_length=80) confirm_token_hash: str = Field(min_length=32, max_length=128) class ConfirmImportRequest(BaseModel): confirm_token: Optional[str] = Field(default=None, max_length=256) class ShareCreateResponse(BaseModel): export_id: str expires_at: str @contextmanager def connect(): DATABASE_PATH.parent.mkdir(parents=True, exist_ok=True) db = sqlite3.connect(DATABASE_PATH) db.row_factory = sqlite3.Row try: yield db db.commit() finally: db.close() def utc_now() -> datetime: return datetime.now(timezone.utc) def iso(value: datetime) -> str: return value.isoformat().replace("+00:00", "Z") def init_db() -> None: with connect() as db: db.execute( """ CREATE TABLE IF NOT EXISTS active_shares ( export_id TEXT PRIMARY KEY, created_at TEXT NOT NULL, expires_at TEXT NOT NULL, ciphertext TEXT NOT NULL, iv TEXT NOT NULL, schema TEXT NOT NULL ) """ ) db.execute( """ CREATE TABLE IF NOT EXISTS share_audit ( export_id TEXT PRIMARY KEY, created_at TEXT NOT NULL, deleted_at TEXT NOT NULL, delete_reason TEXT NOT NULL ) """ ) columns = {row["name"] for row in db.execute("PRAGMA table_info(active_shares)").fetchall()} if "confirm_token_hash" not in columns: db.execute("ALTER TABLE active_shares ADD COLUMN confirm_token_hash TEXT") @app.middleware("http") async def security_headers(request: Request, call_next): response = await call_next(request) if request.url.path.startswith("/api/share"): response.headers["Cache-Control"] = "no-store" response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" response.headers["X-Content-Type-Options"] = "nosniff" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=(), payment=()" return response def decoded_size(value: str) -> int: if not re.fullmatch(r"[A-Za-z0-9_-]+", value): raise HTTPException(status_code=400, detail="Invalid base64url payload.") padded = value.replace("-", "+").replace("_", "/") + "=" * (-len(value) % 4) try: return len(base64.b64decode(padded, validate=True)) except Exception as exc: raise HTTPException(status_code=400, detail="Invalid base64url payload.") from exc def cleanup_expired() -> None: now = iso(utc_now()) with connect() as db: rows = db.execute("SELECT export_id, created_at FROM active_shares WHERE expires_at <= ?", (now,)).fetchall() for row in rows: mark_deleted(db, row["export_id"], row["created_at"], "expired") def mark_deleted(db: sqlite3.Connection, export_id: str, created_at: str, reason: str) -> None: db.execute("DELETE FROM active_shares WHERE export_id = ?", (export_id,)) db.execute( """ INSERT OR REPLACE INTO share_audit (export_id, created_at, deleted_at, delete_reason) VALUES (?, ?, ?, ?) """, (export_id, created_at, iso(utc_now()), reason), ) def client_ip(request: Request) -> str: direct_host = request.client.host if request.client else "unknown" if direct_host in TRUSTED_PROXY_HOSTS: real_ip = request.headers.get("x-real-ip", "").strip() if real_ip: return real_ip return direct_host def assert_rate_limit(request: Request, action: str, limit: int) -> None: ip = client_ip(request) bucket_key = f"{action}:{ip}" cutoff = utc_now() - timedelta(hours=1) recent = [created for created in rate_bucket.get(bucket_key, []) if created > cutoff] if len(recent) >= limit: raise HTTPException(status_code=429, detail="Rate limit exceeded.") recent.append(utc_now()) rate_bucket[bucket_key] = recent def token_hash(confirm_token: str) -> str: digest = hashlib.sha256(confirm_token.encode("utf-8")).digest() return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") @app.on_event("startup") def startup() -> None: init_db() cleanup_expired() @app.post("/api/share", response_model=ShareCreateResponse) def create_share(payload: ShareCreateRequest, request: Request) -> Dict[str, str]: cleanup_expired() assert_rate_limit(request, "create", CREATE_RATE_LIMIT_PER_HOUR) if decoded_size(payload.ciphertext) > MAX_PAYLOAD_BYTES: raise HTTPException(status_code=413, detail="Payload is too large.") if decoded_size(payload.iv) != 12: raise HTTPException(status_code=400, detail="Invalid IV size.") if not re.fullmatch(r"[A-Za-z0-9_-]+", payload.confirm_token_hash): raise HTTPException(status_code=400, detail="Invalid confirm verifier.") now = utc_now() export_id = str(uuid4()) expires_at = now + timedelta(days=EXPIRY_DAYS) with connect() as db: db.execute( """ INSERT INTO active_shares (export_id, created_at, expires_at, ciphertext, iv, schema, confirm_token_hash) VALUES (?, ?, ?, ?, ?, ?, ?) """, (export_id, iso(now), iso(expires_at), payload.ciphertext, payload.iv, payload.share_schema, payload.confirm_token_hash), ) return {"export_id": export_id, "expires_at": iso(expires_at)} @app.get("/api/share/{export_id}") def get_share(export_id: str, request: Request) -> Dict[str, Any]: cleanup_expired() assert_rate_limit(request, "read", READ_RATE_LIMIT_PER_HOUR) with connect() as db: row = db.execute("SELECT * FROM active_shares WHERE export_id = ?", (export_id,)).fetchone() if row: return { "status": "available", "export_id": row["export_id"], "ciphertext": row["ciphertext"], "iv": row["iv"], "schema": row["schema"], "expires_at": row["expires_at"], } audit = db.execute("SELECT delete_reason FROM share_audit WHERE export_id = ?", (export_id,)).fetchone() if audit and audit["delete_reason"] == "expired": return {"status": "expired", "export_id": export_id} if audit: return {"status": "deleted", "export_id": export_id} raise HTTPException(status_code=404, detail="Share not found.") @app.post("/api/share/{export_id}/confirm-import") def confirm_import(export_id: str, payload: ConfirmImportRequest, request: Request) -> Dict[str, str]: cleanup_expired() assert_rate_limit(request, "confirm", CONFIRM_RATE_LIMIT_PER_HOUR) with connect() as db: row = db.execute("SELECT export_id, created_at, confirm_token_hash FROM active_shares WHERE export_id = ?", (export_id,)).fetchone() if not row: raise HTTPException(status_code=404, detail="Share is not available.") expected = row["confirm_token_hash"] if not expected or not payload.confirm_token or len(payload.confirm_token) < 16 or token_hash(payload.confirm_token) != expected: raise HTTPException(status_code=403, detail="Invalid confirmation proof.") mark_deleted(db, row["export_id"], row["created_at"], "imported_by_recipient") return {"status": "deleted"}