Arbeidspuls/backend/app.py
2026-05-29 18:27:05 +02:00

243 lines
8.8 KiB
Python

from __future__ import annotations
import base64
import hashlib
import os
import re
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# --- Storage and retention (both overridable through the environment) ---
# SQLite file used by connect(); its parent directory is created on demand.
DATABASE_PATH = Path(os.environ.get("ARBEIDSPULS_SHARE_DB", "backend/share-data.sqlite3"))
# Lifetime of a share in days, counted from its creation time.
EXPIRY_DAYS = int(os.environ.get("ARBEIDSPULS_SHARE_EXPIRY_DAYS", "7"))

# --- Payload limits ---
# Upper bound on the decoded ciphertext size (2 MiB).
MAX_PAYLOAD_BYTES = 2 * 1024 ** 2
# Upper bound on the encoded ciphertext string accepted by the request model.
MAX_CIPHERTEXT_CHARS = 3_000_000

# --- Per-client request budgets, one per action, per rolling hour ---
CREATE_RATE_LIMIT_PER_HOUR = 10
READ_RATE_LIMIT_PER_HOUR = 120
CONFIRM_RATE_LIMIT_PER_HOUR = 60

# Peers whose X-Real-IP header is believed when resolving the client address.
TRUSTED_PROXY_HOSTS = {"127.0.0.1", "::1", "localhost"}

# Default CORS origins, used when ARBEIDSPULS_CORS_ORIGINS is not set.
ALLOWED_ORIGINS = [
    "https://arbeidspuls.rolfsvaag.no",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]
# Comma-separated origin list from the environment, falling back to the
# built-in ALLOWED_ORIGINS defaults.
_cors_origin_setting = os.environ.get("ARBEIDSPULS_CORS_ORIGINS", ",".join(ALLOWED_ORIGINS))

app = FastAPI(title="Arbeidspuls secure share API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=[entry.strip() for entry in _cors_origin_setting.split(",")],
    allow_credentials=False,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["Content-Type"],
)

# In-memory rate-limit state: "<action>:<client ip>" -> timestamps of the
# requests counted inside the current one-hour window.
rate_bucket: Dict[str, List[datetime]] = {}
class ShareCreateRequest(BaseModel):
    """Request body accepted by ``POST /api/share``."""

    # Encoded ciphertext; the handler additionally checks its decoded size.
    ciphertext: str = Field(max_length=MAX_CIPHERTEXT_CHARS, min_length=1)
    # Encoded initialisation vector; the handler requires it to decode to 12 bytes.
    iv: str = Field(max_length=64, min_length=1)
    # Sent as "schema" on the wire; stored under a different attribute name.
    share_schema: str = Field(alias="schema", max_length=80, min_length=1)
    # Verifier that a later confirm-import call must match (see token_hash).
    confirm_token_hash: str = Field(max_length=128, min_length=32)
class ConfirmImportRequest(BaseModel):
    """Request body accepted by ``POST /api/share/{export_id}/confirm-import``."""

    # Plain confirmation token; optional at the schema level, but the handler
    # rejects a missing or too-short value.
    confirm_token: Optional[str] = Field(max_length=256, default=None)
class ShareCreateResponse(BaseModel):
    """Response returned by ``POST /api/share``."""

    # Identifier of the newly stored share.
    export_id: str
    # Expiry timestamp as produced by iso().
    expires_at: str
@contextmanager
def connect():
    """Open the share database for the duration of a ``with`` block.

    The parent directory of ``DATABASE_PATH`` is created if missing. Rows are
    returned as ``sqlite3.Row`` so columns can be read by name. The
    transaction is committed only when the block finishes without raising;
    the connection is closed in every case.
    """
    storage_dir = DATABASE_PATH.parent
    storage_dir.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(DATABASE_PATH)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
        # Only reached when the caller's block did not raise.
        connection.commit()
    finally:
        connection.close()
def utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
def iso(value: datetime) -> str:
    """Format *value* as a fixed-width ISO-8601 string with a ``Z`` suffix.

    The result is stored in the database and compared as text
    (``expires_at <= ?``), so two timestamps must sort lexically in the same
    order as chronologically. Two measures guarantee that:

    * microseconds are always emitted; plain ``isoformat()`` drops the
      fraction when it is zero, which makes such a value sort *after* later
      instants within the same second;
    * timezone-aware values are normalised to UTC first, so the offset
      suffix is always ``Z``.

    Naive datetimes are formatted as-is and get no suffix.
    """
    if value.utcoffset() is not None:
        value = value.astimezone(timezone.utc)
    return value.isoformat(timespec="microseconds").replace("+00:00", "Z")
def init_db() -> None:
    """Create the share tables when absent and apply the column migration.

    ``active_shares`` holds retrievable shares; ``share_audit`` records when
    and why a share was removed. Databases created before the
    ``confirm_token_hash`` column existed get it added in place.
    """
    table_definitions = (
        """
        CREATE TABLE IF NOT EXISTS active_shares (
        export_id TEXT PRIMARY KEY,
        created_at TEXT NOT NULL,
        expires_at TEXT NOT NULL,
        ciphertext TEXT NOT NULL,
        iv TEXT NOT NULL,
        schema TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS share_audit (
        export_id TEXT PRIMARY KEY,
        created_at TEXT NOT NULL,
        deleted_at TEXT NOT NULL,
        delete_reason TEXT NOT NULL
        )
        """,
    )
    with connect() as db:
        for ddl in table_definitions:
            db.execute(ddl)

        # Migration: the column is not part of the CREATE TABLE statement, so
        # it is added separately when it is missing.
        table_info = db.execute("PRAGMA table_info(active_shares)").fetchall()
        present_columns = {column["name"] for column in table_info}
        if "confirm_token_hash" not in present_columns:
            db.execute("ALTER TABLE active_shares ADD COLUMN confirm_token_hash TEXT")
@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Attach hardening headers to every response.

    Responses under ``/api/share`` additionally get headers that forbid
    caching.
    """
    response = await call_next(request)

    extra_headers: Dict[str, str] = {}
    # NOTE(review): assumes only the three cache headers are scoped to the
    # share API and the remaining headers apply to every path — confirm.
    if request.url.path.startswith("/api/share"):
        extra_headers["Cache-Control"] = "no-store"
        extra_headers["Pragma"] = "no-cache"
        extra_headers["Expires"] = "0"
    extra_headers["X-Content-Type-Options"] = "nosniff"
    extra_headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    extra_headers["Permissions-Policy"] = "camera=(), microphone=(), geolocation=(), payment=()"

    for header_name, header_value in extra_headers.items():
        response.headers[header_name] = header_value
    return response
def decoded_size(value: str) -> int:
    """Return the number of bytes *value* decodes to as unpadded base64url.

    Raises:
        HTTPException: 400 when *value* is empty, contains characters outside
            the base64url alphabet, or cannot be decoded.
    """
    if re.fullmatch(r"[A-Za-z0-9_-]+", value) is None:
        raise HTTPException(status_code=400, detail="Invalid base64url payload.")

    # Map the URL-safe alphabet back to the standard one and restore the
    # padding that the unpadded form leaves out.
    standard_alphabet = value.translate(str.maketrans("-_", "+/"))
    missing_padding = -len(value) % 4
    try:
        raw = base64.b64decode(standard_alphabet + "=" * missing_padding, validate=True)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid base64url payload.") from exc
    return len(raw)
def cleanup_expired() -> None:
    """Move every share whose expiry time has passed into the audit table."""
    cutoff = iso(utc_now())
    expired_query = "SELECT export_id, created_at FROM active_shares WHERE expires_at <= ?"
    with connect() as db:
        # Materialise the result first: the loop deletes from the same table.
        expired_rows = db.execute(expired_query, (cutoff,)).fetchall()
        for expired in expired_rows:
            mark_deleted(db, expired["export_id"], expired["created_at"], "expired")
def mark_deleted(db: sqlite3.Connection, export_id: str, created_at: str, reason: str) -> None:
    """Delete a share and record its removal in ``share_audit``.

    Args:
        db: Open connection; the caller is responsible for committing.
        export_id: Identifier of the share to remove.
        created_at: Creation timestamp copied into the audit row.
        reason: Stored as ``delete_reason``.
    """
    remove_sql = "DELETE FROM active_shares WHERE export_id = ?"
    audit_sql = """
        INSERT OR REPLACE INTO share_audit (export_id, created_at, deleted_at, delete_reason)
        VALUES (?, ?, ?, ?)
        """
    db.execute(remove_sql, (export_id,))
    deleted_at = iso(utc_now())
    db.execute(audit_sql, (export_id, created_at, deleted_at, reason))
def client_ip(request: Request) -> str:
    """Return the address used to identify the caller.

    The ``X-Real-IP`` header is believed only when the directly connected
    peer is listed in ``TRUSTED_PROXY_HOSTS``; otherwise, or when the header
    is blank, the peer address itself is returned ("unknown" when the
    request carries no client information).
    """
    peer = request.client.host if request.client else "unknown"
    if peer not in TRUSTED_PROXY_HOSTS:
        return peer
    forwarded = request.headers.get("x-real-ip", "").strip()
    return forwarded or peer
# Once more than this many keys are tracked, a call also evicts every key
# whose one-hour window has emptied, so the table cannot grow without bound.
_RATE_BUCKET_SWEEP_THRESHOLD = 10_000


def assert_rate_limit(request: Request, action: str, limit: int) -> None:
    """Enforce a per-client, per-action limit over a rolling one-hour window.

    State lives in the in-memory ``rate_bucket`` mapping, keyed by
    ``"<action>:<client ip>"``, so it is local to this process and is lost on
    restart.

    Args:
        request: Incoming request; the client address comes from client_ip().
        action: Label that separates the budgets of different endpoints.
        limit: Maximum number of accepted calls inside the window.

    Raises:
        HTTPException: 429 when the client already used up its budget.
    """
    now = utc_now()
    cutoff = now - timedelta(hours=1)
    bucket_key = f"{action}:{client_ip(request)}"

    recent = [stamp for stamp in rate_bucket.get(bucket_key, []) if stamp > cutoff]
    allowed = len(recent) < limit
    if allowed:
        recent.append(now)
    # Store the pruned window in both outcomes so a throttled client's list
    # does not keep stale timestamps around.
    rate_bucket[bucket_key] = recent

    if len(rate_bucket) > _RATE_BUCKET_SWEEP_THRESHOLD:
        # Timestamps are appended in order, so the last one is the newest.
        # Iterate over a snapshot because entries are removed along the way.
        for key, stamps in list(rate_bucket.items()):
            if not stamps or stamps[-1] <= cutoff:
                rate_bucket.pop(key, None)

    if not allowed:
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
def token_hash(confirm_token: str) -> str:
    """Return the SHA-256 of *confirm_token* as unpadded base64url text."""
    hasher = hashlib.sha256()
    hasher.update(confirm_token.encode("utf-8"))
    encoded = base64.urlsafe_b64encode(hasher.digest())
    return encoded.rstrip(b"=").decode("ascii")
@app.on_event("startup")
def startup() -> None:
    """Prepare storage when the application starts.

    Ensures the tables exist, then clears out shares that expired while the
    service was not running.
    """
    init_db()
    cleanup_expired()
@app.post("/api/share", response_model=ShareCreateResponse)
def create_share(payload: ShareCreateRequest, request: Request) -> Dict[str, str]:
    """Store a new share and return its identifier and expiry time.

    Raises:
        HTTPException: 429 when the caller exceeded the create budget, 413
            when the decoded ciphertext is larger than ``MAX_PAYLOAD_BYTES``,
            400 when the IV does not decode to 12 bytes or a field is not
            valid base64url.
    """
    # Throttle and validate before touching the database, so rejected
    # requests cost no database work.
    assert_rate_limit(request, "create", CREATE_RATE_LIMIT_PER_HOUR)
    if decoded_size(payload.ciphertext) > MAX_PAYLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Payload is too large.")
    if decoded_size(payload.iv) != 12:
        raise HTTPException(status_code=400, detail="Invalid IV size.")
    if not re.fullmatch(r"[A-Za-z0-9_-]+", payload.confirm_token_hash):
        raise HTTPException(status_code=400, detail="Invalid confirm verifier.")

    # Housekeeping: drop shares whose lifetime has run out.
    cleanup_expired()

    now = utc_now()
    export_id = str(uuid4())
    expires_at = now + timedelta(days=EXPIRY_DAYS)
    with connect() as db:
        db.execute(
            """
            INSERT INTO active_shares (export_id, created_at, expires_at, ciphertext, iv, schema, confirm_token_hash)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            (
                export_id,
                iso(now),
                iso(expires_at),
                payload.ciphertext,
                payload.iv,
                payload.share_schema,
                payload.confirm_token_hash,
            ),
        )
    return {"export_id": export_id, "expires_at": iso(expires_at)}
@app.get("/api/share/{export_id}")
def get_share(export_id: str, request: Request) -> Dict[str, Any]:
    """Return a stored share, or the reason it is no longer available.

    Returns:
        ``status`` is "available" together with the stored payload fields,
        "expired" when the share was removed by expiry, or "deleted" when it
        was removed for any other reason.

    Raises:
        HTTPException: 429 when the caller exceeded the read budget, 404 when
            the identifier is unknown.
    """
    # Throttle before any database work so rejected requests stay cheap.
    assert_rate_limit(request, "read", READ_RATE_LIMIT_PER_HOUR)
    # Expiry is enforced by this cleanup, so it must precede the lookup.
    cleanup_expired()

    with connect() as db:
        # Only the returned columns are read; the confirm verifier stays in
        # the database.
        row = db.execute(
            "SELECT export_id, ciphertext, iv, schema, expires_at FROM active_shares WHERE export_id = ?",
            (export_id,),
        ).fetchone()
        if row:
            return {
                "status": "available",
                "export_id": row["export_id"],
                "ciphertext": row["ciphertext"],
                "iv": row["iv"],
                "schema": row["schema"],
                "expires_at": row["expires_at"],
            }
        audit = db.execute(
            "SELECT delete_reason FROM share_audit WHERE export_id = ?",
            (export_id,),
        ).fetchone()

    if audit is None:
        raise HTTPException(status_code=404, detail="Share not found.")
    status = "expired" if audit["delete_reason"] == "expired" else "deleted"
    return {"status": status, "export_id": export_id}
@app.post("/api/share/{export_id}/confirm-import")
def confirm_import(export_id: str, payload: ConfirmImportRequest, request: Request) -> Dict[str, str]:
    """Delete a share once the recipient proves it holds the confirm token.

    The proof is accepted when the supplied token is at least 16 characters
    long and its token_hash() equals the verifier stored with the share.
    Shares without a stored verifier can never be confirmed.

    Raises:
        HTTPException: 429 when the caller exceeded the confirm budget, 404
            when the share is not active, 403 when the proof is missing or
            wrong.
    """
    # Imported here because the module's import block does not provide hmac.
    import hmac

    # Throttle before any database work so rejected requests stay cheap.
    assert_rate_limit(request, "confirm", CONFIRM_RATE_LIMIT_PER_HOUR)
    cleanup_expired()

    with connect() as db:
        row = db.execute(
            "SELECT export_id, created_at, confirm_token_hash FROM active_shares WHERE export_id = ?",
            (export_id,),
        ).fetchone()
        if not row:
            raise HTTPException(status_code=404, detail="Share is not available.")

        expected = row["confirm_token_hash"]
        supplied = payload.confirm_token
        proof_ok = (
            bool(expected)
            and bool(supplied)
            and len(supplied) >= 16
            # Constant-time comparison: response timing must not reveal how
            # much of the verifier matched.
            and hmac.compare_digest(
                token_hash(supplied).encode("ascii"),
                expected.encode("utf-8"),
            )
        )
        if not proof_ok:
            raise HTTPException(status_code=403, detail="Invalid confirmation proof.")

        mark_deleted(db, row["export_id"], row["created_at"], "imported_by_recipient")
    return {"status": "deleted"}