import asyncio
import hmac
import uuid
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse

from itsdangerous import URLSafeTimedSerializer, BadSignature, SignatureExpired
from micropie import HttpMiddleware, App, Request


class CSRFMiddleware(HttpMiddleware):
    """
    MicroPie-ready CSRF middleware using itsdangerous + session binding.

    - Verifies on POST/PUT/PATCH/DELETE
    - Accepts token from body (form or JSON) or 'X-CSRF-Token' header
    - For multipart/form-data, strongly prefer header (parser may still be streaming);
      when the token has to come from a multipart body we poll for it for at most
      ``multipart_wait_timeout`` seconds instead of waiting indefinitely
    - Token payload includes the session_id (when present) to bind token to that session
    - A token issued before the session cookie existed (no 'sid' in its payload) is only
      accepted from a request that carries a session cookie if it is exactly the token
      stored in that session; otherwise any validly signed unbound token would pass for
      every session
    - Emits 'X-CSRF-Token' **only** when we create/rotate one during this request
    - Supports exempt_paths (e.g. webhook endpoints like /sms_order)
    """

    # Seconds to sleep between polls while a multipart body is still being parsed.
    _MULTIPART_POLL_INTERVAL = 0.005

    def __init__(
        self,
        app: App,
        secret_key: str,
        *,
        max_age: int = 24 * 3600,  # 24 hours
        trusted_origins: Optional[List[str]] = None,
        body_field: str = "csrf_token",
        header_name: str = "x-csrf-token",
        require_header_for_multipart: bool = True,
        exempt_paths: Optional[List[str]] = None,
        multipart_wait_timeout: float = 10.0,
    ):
        """
        Args:
            app: the MicroPie application this middleware is attached to.
            secret_key: key used to sign tokens (salted with "csrf-token").
            max_age: token lifetime in seconds, enforced on verification.
            trusted_origins: if given, mutating requests must carry an Origin or
                Referer whose "scheme://host[:port]" is in this list.
            body_field: form/JSON field name that carries the token.
            header_name: request header that carries the token (case-insensitive).
            require_header_for_multipart: check the header first for multipart bodies.
            exempt_paths: exact request paths that skip CSRF handling entirely.
            multipart_wait_timeout: maximum seconds to wait for a still-streaming
                multipart body to expose ``body_field`` before giving up.
        """
        self.app = app
        self.serializer = URLSafeTimedSerializer(secret_key, salt="csrf-token")
        self.max_age = max_age
        # Origins are compared lowercased and without a trailing slash, so
        # "https://Example.com/" in config matches an "https://example.com" header.
        self.trusted = {o.rstrip("/").lower() for o in (trusted_origins or [])}
        self.body_field = body_field
        self.header_name = header_name.lower()
        self.require_header_for_multipart = require_header_for_multipart
        self.exempt_paths = set(exempt_paths or [])
        self.multipart_wait_timeout = multipart_wait_timeout

    # ---------- helpers ----------

    def _is_mutating(self, method: str) -> bool:
        """Return True for the HTTP methods that require a CSRF token."""
        return method in ("POST", "PUT, PATCH, DELETE".split(", ")[0], "PATCH", "DELETE") or method == "PUT"

    def _is_multipart(self, ct: str) -> bool:
        """Return True if the Content-Type denotes multipart/form-data (case-insensitive)."""
        return "multipart/form-data" in (ct or "").lower()

    def _origin_ok(self, headers: Dict[str, str]) -> bool:
        """
        Return True if the Origin or Referer header names a trusted origin.

        Always True when no trusted origins are configured. When some are
        configured and neither header is present, the request is rejected.
        """
        if not self.trusted:
            return True
        for hdr in (headers.get("origin"), headers.get("referer")):
            if not hdr:
                continue
            try:
                parsed = urlparse(hdr)
            except ValueError:
                # Malformed URL (e.g. a bad IPv6 literal): untrusted, try the next header.
                continue
            if f"{parsed.scheme}://{parsed.netloc}".lower() in self.trusted:
                return True
        return False

    def _get_session_id(self, request: Request) -> Optional[str]:
        """
        Return the value of the ``session_id`` cookie, or None if absent or empty.

        Cookies are split on ';' and matched by exact name, so a differently
        named cookie such as ``other_session_id`` is not mistaken for it.
        """
        for part in request.headers.get("cookie", "").split(";"):
            name, sep, value = part.strip().partition("=")
            if sep and name == "session_id":
                return value.strip() or None
        return None

    def _issue_token(self, session_id: Optional[str]) -> str:
        """Create a signed token with a random nonce, bound to ``session_id`` if given."""
        payload = {"nonce": str(uuid.uuid4())}
        if session_id:
            payload["sid"] = session_id
        return self.serializer.dumps(payload)

    async def _wait_for_multipart_field(self, request: Request) -> Any:
        """
        Poll ``request.body_params`` for the token field while a multipart body streams in.

        Stops when the field appears, the parser reports ``body_parsed``, or
        ``multipart_wait_timeout`` elapses, then returns the field's current value.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.multipart_wait_timeout
        while not request.body_parsed and loop.time() < deadline:
            if request.body_params.get(self.body_field):
                break
            # A real (non-zero) sleep so a stalled upload does not pin the CPU.
            await asyncio.sleep(self._MULTIPART_POLL_INTERVAL)
        # Read once more: the parser may have added the field in the same step
        # that set body_parsed, which would otherwise end the loop unseen.
        return request.body_params.get(self.body_field)

    async def _extract_submitted_token(self, request: Request) -> Optional[str]:
        """
        Return the CSRF token submitted with the request, or None.

        Order: header (multipart only, when required) → form body field
        (waiting for a streaming multipart body if needed) → JSON body field → header.
        """
        ct = request.headers.get("content-type", "")
        multipart = self._is_multipart(ct)
        if multipart and self.require_header_for_multipart:
            token = request.headers.get(self.header_name)
            if token:
                return token

        values = request.body_params.get(self.body_field)
        if not values and multipart:
            values = await self._wait_for_multipart_field(request)

        if isinstance(values, list) and values:
            return values[0]

        # MicroPie exposes a parsed JSON body as the ``get_json`` attribute.
        json_body = getattr(request, "get_json", None)
        if isinstance(json_body, dict):
            tok = json_body.get(self.body_field)
            if isinstance(tok, str):
                return tok

        return request.headers.get(self.header_name)

    def _token_matches_session(
        self, request: Request, submitted: Any, data: Any, sid: Optional[str]
    ) -> bool:
        """
        Decide whether a signature-valid token may be used by this request.

        - Token bound to a sid: it must equal the request's session cookie value.
        - Unbound token, request without a session cookie: accepted (nothing to bind to).
        - Unbound token, request with a session cookie: it must be exactly the token
          stored in that session; otherwise an attacker could mint an unbound token
          from a cookieless visit and replay it against any victim's session.
        """
        token_sid = data.get("sid") if isinstance(data, dict) else None
        if token_sid is not None:
            return token_sid == sid
        if sid is None:
            return True
        stored = request.session.get("csrf_token")
        if not isinstance(stored, str) or not isinstance(submitted, str):
            return False
        return hmac.compare_digest(stored.encode("utf-8"), submitted.encode("utf-8"))

    # ---------- middleware hooks ----------

    async def before_request(self, request: Request) -> Optional[Dict]:
        """
        Issue a token if the session lacks one, then verify and rotate it on mutating requests.

        Returns None to continue, or a 403 response dict to short-circuit.
        """
        path = request.scope.get("path", "")

        # Exempt specific paths (e.g. /sms_order webhook)
        if path in self.exempt_paths:
            return None

        sid = self._get_session_id(request)

        if "csrf_token" not in request.session:
            request.session["csrf_token"] = self._issue_token(sid)
            setattr(request, "_csrf_emit", request.session["csrf_token"])

        if not self._is_mutating(request.method):
            return None

        if not self._origin_ok(request.headers):
            return {"status_code": 403, "body": "Forbidden: invalid origin/referer"}

        submitted = await self._extract_submitted_token(request)
        if not submitted:
            return {"status_code": 403, "body": "Missing CSRF token"}

        try:
            data = self.serializer.loads(submitted, max_age=self.max_age)
        except SignatureExpired:
            # Must precede BadSignature: SignatureExpired is a subclass of it.
            return {
                "status_code": 403,
                "body": "<h1>Expired CSRF token, please reload the page and try again.</h1>",
            }
        except BadSignature:
            return {
                "status_code": 403,
                "body": "<h1>Invalid CSRF token signature, please reload the page and try again.</h1>",
            }

        if not self._token_matches_session(request, submitted, data, sid):
            return {"status_code": 403, "body": "Invalid CSRF token for this session"}

        # Rotate after every successful verification; the new token is emitted in after_request.
        new_token = self._issue_token(sid)
        request.session["csrf_token"] = new_token
        setattr(request, "_csrf_emit", new_token)

        return None

    async def after_request(
        self,
        request: Request,
        status_code: int,
        response_body: Any,
        extra_headers: List[Tuple[str, str]],
    ) -> Optional[Dict]:
        """Append 'X-CSRF-Token' if a token was created or rotated during this request."""
        to_emit = getattr(request, "_csrf_emit", None)
        if to_emit:
            extra_headers.append(("X-CSRF-Token", to_emit))
        return None
