243 lines
8.8 KiB
Python
243 lines
8.8 KiB
Python
from __future__ import annotations

import base64
import hashlib
import hmac
import os
import re
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
|
|
|
|
# --- Storage ----------------------------------------------------------------
# Both values can be overridden through environment variables.
DATABASE_PATH = Path(os.environ.get("ARBEIDSPULS_SHARE_DB", "backend/share-data.sqlite3"))
EXPIRY_DAYS = int(os.environ.get("ARBEIDSPULS_SHARE_EXPIRY_DAYS", "7"))  # lifetime of a new share

# --- Payload limits ---------------------------------------------------------
MAX_PAYLOAD_BYTES = 2 << 20  # 2 MiB: cap on the *decoded* ciphertext size
MAX_CIPHERTEXT_CHARS = 3_000_000  # cap on the encoded ciphertext string length

# --- Rate limits (requests per client within a sliding one-hour window) -----
CREATE_RATE_LIMIT_PER_HOUR = 10
READ_RATE_LIMIT_PER_HOUR = 120
CONFIRM_RATE_LIMIT_PER_HOUR = 60

# --- Network ----------------------------------------------------------------
# Direct peers whose X-Real-IP header is trusted as the real client address.
TRUSTED_PROXY_HOSTS = {"127.0.0.1", "::1", "localhost"}

# Default CORS allow-list; a comma-separated ARBEIDSPULS_CORS_ORIGINS replaces it.
ALLOWED_ORIGINS = [
    "https://arbeidspuls.rolfsvaag.no",
    "http://localhost:5173",
    "http://127.0.0.1:5173",
]
|
|
|
|
app = FastAPI(title="Arbeidspuls secure share API")

# CORS: a comma-separated ARBEIDSPULS_CORS_ORIGINS, when set, replaces the
# built-in ALLOWED_ORIGINS list (each entry is whitespace-stripped).
# Credentials are never allowed; only the methods and the single request
# header the share endpoints need are permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=list(
        map(
            str.strip,
            os.environ.get("ARBEIDSPULS_CORS_ORIGINS", ",".join(ALLOWED_ORIGINS)).split(","),
        )
    ),
    allow_credentials=False,
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["Content-Type"],
)

# Sliding-window request log for assert_rate_limit: maps
# "<action>:<client ip>" to the timestamps of that client's recent requests.
# It lives in process memory, so every worker process keeps its own counts.
rate_bucket: Dict[str, List[datetime]] = {}
|
|
|
|
|
|
class ShareCreateRequest(BaseModel):
    """JSON body accepted by ``POST /api/share``.

    The field constraints below only bound string lengths. create_share
    performs the remaining checks: base64url decoding, decoded size, IV
    length and the verifier alphabet.
    """

    # Bounds the *encoded* length; the decoded byte limit is checked
    # separately against MAX_PAYLOAD_BYTES.
    ciphertext: str = Field(..., min_length=1, max_length=MAX_CIPHERTEXT_CHARS)
    iv: str = Field(..., min_length=1, max_length=64)
    # Sent by clients as "schema". The attribute is named share_schema,
    # presumably to avoid clashing with BaseModel's own ``schema`` member
    # (NOTE(review): confirm).
    share_schema: str = Field(..., alias="schema", min_length=1, max_length=80)
    # Stored verifier; confirm_import later compares it with
    # token_hash(confirm_token).
    confirm_token_hash: str = Field(..., min_length=32, max_length=128)
|
|
|
|
|
|
class ConfirmImportRequest(BaseModel):
    """JSON body for ``POST /api/share/{export_id}/confirm-import``."""

    # Optional at the schema level: confirm_import itself rejects a missing
    # or too-short token with 403.
    confirm_token: Optional[str] = Field(None, max_length=256)
|
|
|
|
|
|
class ShareCreateResponse(BaseModel):
    """Response model declared for ``POST /api/share``."""

    export_id: str  # str(uuid4()) generated by create_share
    expires_at: str  # timestamp rendered by iso(), ending in "Z"
|
|
|
|
|
|
@contextmanager
def connect():
    """Yield a SQLite connection to DATABASE_PATH for one ``with`` block.

    The database's parent directory is created on demand. Rows are returned
    as ``sqlite3.Row`` so columns can be read by name. The transaction is
    committed only when the ``with`` body finishes without raising. The
    connection is always closed; closing without a commit discards any
    uncommitted changes.
    """
    DATABASE_PATH.parent.mkdir(parents=True, exist_ok=True)
    connection = sqlite3.connect(DATABASE_PATH)
    connection.row_factory = sqlite3.Row
    try:
        yield connection
        # Only reached when the caller's block exited normally.
        connection.commit()
    finally:
        connection.close()
|
|
|
|
|
|
def utc_now() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    return datetime.now(tz=timezone.utc)
|
|
|
|
|
|
def iso(value: datetime) -> str:
    """Render *value* as a fixed-shape ISO-8601 UTC string ending in ``Z``.

    These strings are stored as TEXT and compared with SQL ``<=`` (see
    cleanup_expired), so ordering is lexicographic. That is only correct
    when every string has the same shape, so:

    * fractional seconds are always written with six digits. Plain
      ``isoformat()`` omits them when ``microsecond == 0``, and
      ``"...:05Z"`` would then sort *after* ``"...:05.5Z"``;
    * timezone-aware values are converted to UTC first, so the ``Z``
      suffix is always accurate.

    Naive datetimes are formatted as-is, with no suffix added. Callers in
    this module pass values derived from utc_now().
    """
    if value.tzinfo is not None:
        value = value.astimezone(timezone.utc)
    return value.isoformat(timespec="microseconds").replace("+00:00", "Z")
|
|
|
|
|
|
def init_db() -> None:
    """Create the SQLite schema if missing and migrate older databases.

    ``active_shares`` holds live payloads. ``share_audit`` keeps one
    tombstone row per removed share (export_id is its primary key).
    Databases created before the ``confirm_token_hash`` column existed get
    it added as a nullable TEXT column, so pre-existing rows hold NULL there.
    """
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS active_shares (
            export_id TEXT PRIMARY KEY,
            created_at TEXT NOT NULL,
            expires_at TEXT NOT NULL,
            ciphertext TEXT NOT NULL,
            iv TEXT NOT NULL,
            schema TEXT NOT NULL
        )
        """,
        """
        CREATE TABLE IF NOT EXISTS share_audit (
            export_id TEXT PRIMARY KEY,
            created_at TEXT NOT NULL,
            deleted_at TEXT NOT NULL,
            delete_reason TEXT NOT NULL
        )
        """,
    )
    with connect() as db:
        for statement in schema_statements:
            db.execute(statement)
        # In-place migration: add the verifier column only when it is missing.
        known_columns = {info["name"] for info in db.execute("PRAGMA table_info(active_shares)")}
        if "confirm_token_hash" not in known_columns:
            db.execute("ALTER TABLE active_shares ADD COLUMN confirm_token_hash TEXT")
|
|
|
|
|
|
@app.middleware("http")
async def security_headers(request: Request, call_next):
    """Add defensive HTTP headers to each response.

    Responses under ``/api/share`` also get headers that forbid caching.
    Every response gets the nosniff, referrer and permissions policies.
    """
    # NOTE(review): the indentation of the original was not recoverable;
    # this assumes only the three cache headers are scoped to /api/share.
    response = await call_next(request)
    headers_to_set: Dict[str, str] = {}
    if request.url.path.startswith("/api/share"):
        headers_to_set["Cache-Control"] = "no-store"
        headers_to_set["Pragma"] = "no-cache"
        headers_to_set["Expires"] = "0"
    headers_to_set.update(
        {
            "X-Content-Type-Options": "nosniff",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "camera=(), microphone=(), geolocation=(), payment=()",
        }
    )
    for header_name, header_value in headers_to_set.items():
        response.headers[header_name] = header_value
    return response
|
|
|
|
|
|
# Unpadded base64url alphabet, and the mapping to the standard alphabet
# accepted by base64.b64decode.
_BASE64URL_PATTERN = re.compile(r"[A-Za-z0-9_-]+")
_BASE64URL_TO_STANDARD = str.maketrans("-_", "+/")


def decoded_size(value: str) -> int:
    """Return how many bytes *value* decodes to as unpadded base64url.

    Raises:
        HTTPException: 400 when *value* is empty, contains characters outside
            the base64url alphabet, or has a length no base64 string can have
            (e.g. one more than a multiple of four).
    """
    if not _BASE64URL_PATTERN.fullmatch(value):
        raise HTTPException(status_code=400, detail="Invalid base64url payload.")
    # Restore the standard alphabet and the padding that base64url omits.
    standard = value.translate(_BASE64URL_TO_STANDARD) + "=" * (-len(value) % 4)
    try:
        return len(base64.b64decode(standard, validate=True))
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail="Invalid base64url payload.") from exc
|
|
|
|
|
|
def cleanup_expired() -> None:
    """Tombstone every active share whose ``expires_at`` has passed.

    The stored ISO-8601 text is compared with the current time rendered by
    iso(), i.e. lexicographically. Each match goes to mark_deleted with
    reason "expired".
    """
    cutoff = iso(utc_now())
    with connect() as db:
        # Fetch all matches first, because mark_deleted deletes from the same table.
        overdue = db.execute(
            "SELECT export_id, created_at FROM active_shares WHERE expires_at <= ?",
            (cutoff,),
        ).fetchall()
        for share in overdue:
            mark_deleted(db, share["export_id"], share["created_at"], "expired")
|
|
|
|
|
|
def mark_deleted(db: sqlite3.Connection, export_id: str, created_at: str, reason: str) -> None:
    """Remove a share's payload row and record a tombstone for it.

    Runs on the caller's connection and does not commit itself. The audit
    write is INSERT OR REPLACE, so repeating it for the same export_id
    overwrites the earlier tombstone instead of failing on the primary key.
    """
    removal_sql = "DELETE FROM active_shares WHERE export_id = ?"
    tombstone_sql = """
        INSERT OR REPLACE INTO share_audit (export_id, created_at, deleted_at, delete_reason)
        VALUES (?, ?, ?, ?)
        """
    db.execute(removal_sql, (export_id,))
    db.execute(tombstone_sql, (export_id, created_at, iso(utc_now()), reason))
|
|
|
|
|
|
def client_ip(request: Request) -> str:
    """Return the client address used as the rate-limit key.

    ``X-Real-IP`` is honoured only when the direct peer is in
    TRUSTED_PROXY_HOSTS (presumably a reverse proxy on the same host).
    Otherwise any client could spoof the header, so the peer address is
    used as-is. Falls back to "unknown" when the peer address is unavailable.
    """
    peer = "unknown"
    if request.client:
        peer = request.client.host
    if peer not in TRUSTED_PROXY_HOSTS:
        return peer
    forwarded = request.headers.get("x-real-ip", "").strip()
    return forwarded or peer
|
|
|
|
|
|
# FastAPI runs plain ``def`` endpoints in a worker thread pool, so several
# requests can be inside assert_rate_limit at once. This lock makes the
# read-check-append on ``rate_bucket`` atomic.
_rate_bucket_lock = threading.Lock()

# Once more keys than this are tracked, keys with no hit inside the current
# window are swept, so the in-memory map cannot grow without bound.
_RATE_BUCKET_SWEEP_THRESHOLD = 1024


def assert_rate_limit(request: Request, action: str, limit: int) -> None:
    """Count one request for (action, client) in a one-hour sliding window.

    Timestamps are kept in the module-level ``rate_bucket`` under the key
    ``"<action>:<client ip>"``. Rejected requests are not counted.

    Raises:
        HTTPException: 429 when the client already made ``limit`` requests
            of this kind within the past hour.
    """
    bucket_key = f"{action}:{client_ip(request)}"
    now = utc_now()
    cutoff = now - timedelta(hours=1)
    with _rate_bucket_lock:
        if len(rate_bucket) > _RATE_BUCKET_SWEEP_THRESHOLD:
            _sweep_stale_rate_buckets(cutoff)
        recent = [created for created in rate_bucket.get(bucket_key, []) if created > cutoff]
        rate_bucket[bucket_key] = recent
        if len(recent) >= limit:
            raise HTTPException(status_code=429, detail="Rate limit exceeded.")
        recent.append(now)


def _sweep_stale_rate_buckets(cutoff: datetime) -> None:
    """Delete rate-limit keys with no timestamp newer than *cutoff* (caller holds the lock)."""
    stale_keys = [
        key for key, stamps in rate_bucket.items() if not any(stamp > cutoff for stamp in stamps)
    ]
    for key in stale_keys:
        del rate_bucket[key]
|
|
|
|
|
|
def token_hash(confirm_token: str) -> str:
    """Return the SHA-256 of the UTF-8 token as unpadded base64url text."""
    raw = confirm_token.encode("utf-8")
    encoded = base64.urlsafe_b64encode(hashlib.sha256(raw).digest())
    return encoded.rstrip(b"=").decode("ascii")
|
|
|
|
|
|
@app.on_event("startup")
def startup() -> None:
    """On application start, create or migrate the schema, then purge expired shares."""
    # Order matters: cleanup_expired queries the tables init_db creates.
    for step in (init_db, cleanup_expired):
        step()
|
|
|
|
|
|
@app.post("/api/share", response_model=ShareCreateResponse)
def create_share(payload: ShareCreateRequest, request: Request) -> Dict[str, str]:
    """Store a new share and return its id and expiry timestamp.

    Checks run in this order:

    * the ciphertext must decode to at most MAX_PAYLOAD_BYTES (413 otherwise);
    * the IV must decode to exactly 12 bytes (400 otherwise);
    * the confirm verifier must use the base64url alphabet (400 otherwise).

    Malformed base64url in the ciphertext or IV is rejected with 400 by
    decoded_size.
    """
    cleanup_expired()
    assert_rate_limit(request, "create", CREATE_RATE_LIMIT_PER_HOUR)

    ciphertext_bytes = decoded_size(payload.ciphertext)
    if ciphertext_bytes > MAX_PAYLOAD_BYTES:
        raise HTTPException(status_code=413, detail="Payload is too large.")
    iv_bytes = decoded_size(payload.iv)
    if iv_bytes != 12:  # NOTE(review): 12 bytes suggests an AES-GCM nonce; confirm with the client
        raise HTTPException(status_code=400, detail="Invalid IV size.")
    if re.fullmatch(r"[A-Za-z0-9_-]+", payload.confirm_token_hash) is None:
        raise HTTPException(status_code=400, detail="Invalid confirm verifier.")

    created = utc_now()
    expiry = created + timedelta(days=EXPIRY_DAYS)
    new_row = (
        str(uuid4()),
        iso(created),
        iso(expiry),
        payload.ciphertext,
        payload.iv,
        payload.share_schema,
        payload.confirm_token_hash,
    )
    with connect() as db:
        db.execute(
            """
            INSERT INTO active_shares (export_id, created_at, expires_at, ciphertext, iv, schema, confirm_token_hash)
            VALUES (?, ?, ?, ?, ?, ?, ?)
            """,
            new_row,
        )
    export_id, _, expires_at = new_row[:3]
    return {"export_id": export_id, "expires_at": expires_at}
|
|
|
|
|
|
@app.get("/api/share/{export_id}")
def get_share(export_id: str, request: Request) -> Dict[str, Any]:
    """Return a share's encrypted payload, or report why it is gone.

    While the share is active, the response has ``status: "available"``
    plus ciphertext, iv, schema and expires_at. When only its audit
    tombstone remains, the status is ``"expired"`` or ``"deleted"``.

    Raises:
        HTTPException: 404 when the id is in neither table; 429 from the
            rate limiter.
    """
    cleanup_expired()
    assert_rate_limit(request, "read", READ_RATE_LIMIT_PER_HOUR)
    with connect() as db:
        share = db.execute("SELECT * FROM active_shares WHERE export_id = ?", (export_id,)).fetchone()
        if not share:
            tombstone = db.execute(
                "SELECT delete_reason FROM share_audit WHERE export_id = ?", (export_id,)
            ).fetchone()
            if not tombstone:
                raise HTTPException(status_code=404, detail="Share not found.")
            gone_status = "expired" if tombstone["delete_reason"] == "expired" else "deleted"
            return {"status": gone_status, "export_id": export_id}
        return {
            "status": "available",
            "export_id": share["export_id"],
            "ciphertext": share["ciphertext"],
            "iv": share["iv"],
            "schema": share["schema"],
            "expires_at": share["expires_at"],
        }
|
|
|
|
|
|
@app.post("/api/share/{export_id}/confirm-import")
def confirm_import(export_id: str, payload: ConfirmImportRequest, request: Request) -> Dict[str, str]:
    """Delete a share once the recipient proves it holds the confirmation token.

    The caller sends the plaintext ``confirm_token``. Its token_hash() must
    match the verifier stored by create_share. On success the share is
    tombstoned with reason "imported_by_recipient".

    Raises:
        HTTPException: 404 if the share is not active (unknown, expired or
            already deleted); 403 if the proof is missing, shorter than 16
            characters or wrong; 429 from the rate limiter.
    """
    cleanup_expired()
    assert_rate_limit(request, "confirm", CONFIRM_RATE_LIMIT_PER_HOUR)
    with connect() as db:
        row = db.execute(
            "SELECT export_id, created_at, confirm_token_hash FROM active_shares WHERE export_id = ?",
            (export_id,),
        ).fetchone()
        if not row:
            raise HTTPException(status_code=404, detail="Share is not available.")
        if not _confirm_token_matches(payload.confirm_token, row["confirm_token_hash"]):
            raise HTTPException(status_code=403, detail="Invalid confirmation proof.")
        mark_deleted(db, row["export_id"], row["created_at"], "imported_by_recipient")
    return {"status": "deleted"}


def _confirm_token_matches(confirm_token: Optional[str], expected: Optional[str]) -> bool:
    """Return True when *confirm_token* hashes to *expected*, compared in constant time."""
    # Rows that predate the confirm_token_hash column hold NULL here, so
    # such shares can never be confirmed.
    if not expected or not confirm_token or len(confirm_token) < 16:
        return False
    # hmac.compare_digest avoids revealing, through response timing, how many
    # leading characters of the stored verifier a guess matched.
    return hmac.compare_digest(token_hash(confirm_token).encode("ascii"), expected.encode("utf-8"))
|