Arbeidspuls/backend/app.py
2026-05-29 17:52:54 +02:00

202 lines
6.6 KiB
Python

from __future__ import annotations
import base64
import os
import re
import sqlite3
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, List
from uuid import uuid4
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# --- Runtime configuration (all overridable through environment variables) ---
# SQLite file holding active shares and the deletion audit trail.
DATABASE_PATH = Path(os.environ.get("ARBEIDSPULS_SHARE_DB", "backend/share-data.sqlite3"))
# How long a share stays retrievable after creation.
EXPIRY_DAYS = int(os.environ.get("ARBEIDSPULS_SHARE_EXPIRY_DAYS", "7"))
# Upper bound on the *decoded* ciphertext size accepted by POST /api/share.
MAX_PAYLOAD_BYTES = 2 * 1024 * 1024
# Share creations allowed per client IP in a sliding one-hour window.
# Env-configurable like the other settings; the default is unchanged.
RATE_LIMIT_PER_HOUR = int(os.environ.get("ARBEIDSPULS_SHARE_RATE_LIMIT_PER_HOUR", "10"))
# Default CORS origins, used when ARBEIDSPULS_CORS_ORIGINS is not set.
ALLOWED_ORIGINS = [
    "https://arbeidspuls.rolfsvaag.no",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]

app = FastAPI(title="Arbeidspuls secure share API")
app.add_middleware(
    CORSMiddleware,
    # Comma-separated list; blank items (e.g. from a trailing comma) are dropped
    # instead of being registered as an empty-string origin.
    allow_origins=[
        origin.strip()
        for origin in os.environ.get("ARBEIDSPULS_CORS_ORIGINS", ",".join(ALLOWED_ORIGINS)).split(",")
        if origin.strip()
    ],
    allow_credentials=False,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["Content-Type"],
)

# In-process rate-limit state: client IP -> timestamps of its recent share creations.
rate_bucket: Dict[str, List[datetime]] = {}
class ShareCreateRequest(BaseModel):
    """JSON body accepted by ``POST /api/share``.

    ``ciphertext`` and ``iv`` are base64url strings that ``create_share`` checks
    with ``decoded_size``. ``share_schema`` is read from the JSON key ``schema``
    through its alias.
    """

    ciphertext: str = Field(..., min_length=1)
    iv: str = Field(..., min_length=1)
    share_schema: str = Field(..., alias="schema", min_length=1, max_length=80)
class ShareCreateResponse(BaseModel):
    """JSON body returned by ``POST /api/share`` for a newly stored share."""

    export_id: str  # identifier generated by create_share (str(uuid4()))
    expires_at: str  # expiry timestamp as produced by iso()
@contextmanager
def connect():
    """Yield a SQLite connection to DATABASE_PATH for the duration of a ``with`` block.

    The database's parent directory is created on demand, and rows are returned
    as ``sqlite3.Row`` so columns can be read by name. The transaction is
    committed only if the ``with`` body completes without raising. The
    connection is closed in every case, and sqlite3 does not commit on close,
    so pending changes from a failed body are discarded.
    """
    target = DATABASE_PATH
    target.parent.mkdir(exist_ok=True, parents=True)
    conn = sqlite3.connect(target)
    conn.row_factory = sqlite3.Row
    try:
        yield conn
        conn.commit()  # reached only when the caller's block did not raise
    finally:
        conn.close()
def utc_now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    current = datetime.now(tz=timezone.utc)
    return current
def iso(value: datetime) -> str:
    """Format ``value`` as a fixed-width UTC ISO-8601 string ending in ``Z``.

    Timestamps are stored as TEXT and compared with ``<=`` in SQL, which is a
    lexicographic comparison. ``datetime.isoformat()`` drops the fractional
    part when ``microsecond == 0`` (``...:54Z`` vs ``...:54.123456Z``), and
    because ``"Z" > "."`` such a value would sort after later instants in the
    same second. Always emitting six fractional digits keeps string order equal
    to chronological order.

    Naive datetimes are assumed to already be in UTC. Aware datetimes in any
    other zone are converted to UTC first.
    """
    if value.tzinfo is None:
        value = value.replace(tzinfo=timezone.utc)
    utc_value = value.astimezone(timezone.utc)
    return utc_value.isoformat(timespec="microseconds").replace("+00:00", "Z")
def init_db() -> None:
    """Create the ``active_shares`` and ``share_audit`` tables if they are missing.

    ``active_shares`` holds the retrievable payloads. ``share_audit`` keeps one
    row per removed share (creation time, deletion time, reason) after the
    payload row has been deleted.
    """
    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS active_shares (
            export_id TEXT PRIMARY KEY,
            created_at TEXT NOT NULL,
            expires_at TEXT NOT NULL,
            ciphertext TEXT NOT NULL,
            iv TEXT NOT NULL,
            schema TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS share_audit (
            export_id TEXT PRIMARY KEY,
            created_at TEXT NOT NULL,
            deleted_at TEXT NOT NULL,
            delete_reason TEXT NOT NULL
        )
        """,
    )
    with connect() as db:
        for ddl in ddl_statements:
            db.execute(ddl)
def decoded_size(value: str) -> int:
    """Return how many bytes an unpadded base64url string decodes to.

    ``value`` must use only the URL-safe alphabet (``A-Z``, ``a-z``, ``0-9``,
    ``-``, ``_``) and no ``=`` padding. The padding is restored here before the
    string is decoded with strict validation.

    Raises:
        HTTPException: 400 if ``value`` contains characters outside that
            alphabet or otherwise fails to decode (e.g. a length no byte count
            can produce).
    """
    if not re.fullmatch(r"[A-Za-z0-9_-]+", value):
        raise HTTPException(status_code=400, detail="Invalid base64url payload.")
    # Map the URL-safe alphabet onto the standard one and add the stripped padding.
    padded = value.replace("-", "+").replace("_", "/") + "=" * (-len(value) % 4)
    try:
        return len(base64.b64decode(padded, validate=True))
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail="Invalid base64url payload.") from exc
def cleanup_expired() -> None:
    """Retire every share whose ``expires_at`` is at or before the current time.

    ``expires_at`` is compared as TEXT against ``iso(utc_now())``. Each matching
    share is passed to ``mark_deleted`` with reason ``"expired"``, all on one
    connection that commits once at the end.
    """
    cutoff = iso(utc_now())
    with connect() as db:
        stale_rows = db.execute(
            "SELECT export_id, created_at FROM active_shares WHERE expires_at <= ?",
            (cutoff,),
        ).fetchall()
        for stale in stale_rows:
            mark_deleted(db, stale["export_id"], stale["created_at"], "expired")
def mark_deleted(db: sqlite3.Connection, export_id: str, created_at: str, reason: str) -> None:
    """Drop a share's payload row and record the removal in ``share_audit``.

    Args:
        db: Open connection; committing is left to the caller.
        export_id: Share to remove.
        created_at: Original creation timestamp, copied into the audit row.
        reason: Why the share was removed (this file uses ``"expired"`` and
            ``"imported_by_recipient"``).

    An existing audit row for the same ``export_id`` is replaced.
    """
    db.execute("DELETE FROM active_shares WHERE export_id = ?", (export_id,))
    audit_row = (export_id, created_at, iso(utc_now()), reason)
    db.execute(
        """
        INSERT OR REPLACE INTO share_audit (export_id, created_at, deleted_at, delete_reason)
        VALUES (?, ?, ?, ?)
        """,
        audit_row,
    )
def client_ip(request: Request, trusted_hops: int | None = None) -> str:
    """Return the address used to identify a client for rate limiting.

    ``X-Forwarded-For`` is a comma-separated list to which each proxy *appends*
    the address it received the request from. Only the right-most entries are
    written by infrastructure in front of this app; the left-most entry is
    whatever the client chose to send. Trusting it (as before) let any client
    dodge the rate limit by sending a fake header. The entry ``trusted_hops``
    positions from the right is used instead.

    Args:
        request: Incoming request; only ``headers`` and ``client`` are read.
        trusted_hops: Number of reverse proxies in front of the app that append
            to ``X-Forwarded-For``. ``None`` reads the
            ``ARBEIDSPULS_TRUSTED_PROXY_HOPS`` environment variable (default
            ``1``). ``0`` or less ignores the header, which is the safe choice
            when the app is reachable directly.

    Returns:
        The selected forwarded address, else the socket peer host, else
        ``"unknown"``.
    """
    if trusted_hops is None:
        trusted_hops = int(os.environ.get("ARBEIDSPULS_TRUSTED_PROXY_HOPS", "1"))
    forwarded = request.headers.get("x-forwarded-for", "")
    hops = [part.strip() for part in forwarded.split(",") if part.strip()]
    if trusted_hops > 0 and hops:
        # Fewer entries than configured proxies: the left-most one is all we have.
        return hops[-min(trusted_hops, len(hops))]
    return request.client.host if request.client else "unknown"
def assert_rate_limit(request: Request) -> None:
    """Allow at most RATE_LIMIT_PER_HOUR calls per client IP in a sliding hour.

    The call is recorded in ``rate_bucket`` only when it is allowed.

    Raises:
        HTTPException: 429 when the client already reached its quota within the
            last hour.
    """
    ip = client_ip(request)
    now = utc_now()
    cutoff = now - timedelta(hours=1)
    # Evict clients with no activity inside the window. Without this, the dict
    # keeps a key for every distinct IP ever seen and grows without bound.
    for known_ip, stamps in list(rate_bucket.items()):
        if not stamps or max(stamps) <= cutoff:
            rate_bucket.pop(known_ip, None)
    recent = [created for created in rate_bucket.get(ip, []) if created > cutoff]
    if len(recent) >= RATE_LIMIT_PER_HOUR:
        raise HTTPException(status_code=429, detail="Rate limit exceeded.")
    recent.append(now)
    rate_bucket[ip] = recent
@app.on_event("startup")
def startup() -> None:
    """On server start, create any missing tables, then retire shares that expired while the server was down."""
    # Order matters: the tables must exist before cleanup can query them.
    for step in (init_db, cleanup_expired):
        step()
@app.post("/api/share", response_model=ShareCreateResponse)
def create_share(payload: ShareCreateRequest, request: Request) -> Dict[str, str]:
    """Store an encrypted payload and return its id and expiry time.

    Steps: purge expired shares, apply the per-IP rate limit, check that the
    decoded ciphertext is at most MAX_PAYLOAD_BYTES and the decoded IV is
    exactly 12 bytes, then insert the row.

    Raises:
        HTTPException: 429 (rate limit), 413 (ciphertext too large), or
            400 (bad base64url or wrong IV size).
    """
    cleanup_expired()
    assert_rate_limit(request)

    ciphertext_bytes = decoded_size(payload.ciphertext)
    if ciphertext_bytes > MAX_PAYLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Payload is too large.")
    # 12-byte (96-bit) IV; presumably AES-GCM on the client side — confirm.
    expected_iv_bytes = 12
    if decoded_size(payload.iv) != expected_iv_bytes:
        raise HTTPException(status_code=400, detail="Invalid IV size.")

    created = utc_now()
    export_id = str(uuid4())
    expires = created + timedelta(days=EXPIRY_DAYS)
    created_text = iso(created)
    expires_text = iso(expires)
    row_values = (
        export_id,
        created_text,
        expires_text,
        payload.ciphertext,
        payload.iv,
        payload.share_schema,
    )
    with connect() as db:
        db.execute(
            """
            INSERT INTO active_shares (export_id, created_at, expires_at, ciphertext, iv, schema)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            row_values,
        )
    return {"export_id": export_id, "expires_at": expires_text}
@app.get("/api/share/{export_id}")
def get_share(export_id: str) -> Dict[str, Any]:
    """Fetch a share, or report why it is no longer available.

    Returns:
        ``status="available"`` with the stored payload fields when the share is
        active. Otherwise ``status="expired"`` or ``status="deleted"``, depending
        on the audit row's ``delete_reason``.

    Raises:
        HTTPException: 404 when neither an active nor an audit row exists.
    """
    cleanup_expired()
    with connect() as db:
        share = db.execute("SELECT * FROM active_shares WHERE export_id = ?", (export_id,)).fetchone()
        if share is not None:
            return {
                "status": "available",
                "export_id": share["export_id"],
                "ciphertext": share["ciphertext"],
                "iv": share["iv"],
                "schema": share["schema"],
                "expires_at": share["expires_at"],
            }
        audit = db.execute("SELECT delete_reason FROM share_audit WHERE export_id = ?", (export_id,)).fetchone()
        if audit is None:
            raise HTTPException(status_code=404, detail="Share not found.")
        # Any reason other than "expired" (e.g. an import) is reported as "deleted".
        status = "expired" if audit["delete_reason"] == "expired" else "deleted"
        return {"status": status, "export_id": export_id}
@app.post("/api/share/{export_id}/confirm-import")
def confirm_import(export_id: str) -> Dict[str, str]:
    """Delete an active share after the recipient reports a successful import.

    The share is moved to the audit table with reason ``"imported_by_recipient"``.

    Raises:
        HTTPException: 404 when no active share with this id exists (including
            one that has just expired).
    """
    cleanup_expired()
    with connect() as db:
        share = db.execute(
            "SELECT export_id, created_at FROM active_shares WHERE export_id = ?",
            (export_id,),
        ).fetchone()
        if share is None:
            raise HTTPException(status_code=404, detail="Share is not available.")
        mark_deleted(db, share["export_id"], share["created_at"], "imported_by_recipient")
    return {"status": "deleted"}