from html import escape
import os
import uuid
from typing import Optional, Dict, List, Tuple, Any
from itsdangerous import URLSafeTimedSerializer, BadSignature
from micropie import App, HttpMiddleware, Request, SESSION_TIMEOUT


class SignedSessionMiddleware(HttpMiddleware):
    """Middleware to sign and verify session cookies using itsdangerous.

    The ``session_id`` cookie carries a session identifier signed (and
    timestamped) by ``URLSafeTimedSerializer``. The session data itself is
    kept in ``app.session_backend`` under the unsigned identifier. A cookie
    whose signature does not verify, or that is older than ``max_age``
    seconds, is treated exactly like a missing cookie.
    """

    def __init__(self, app: App, secret_key: str, max_age: int = SESSION_TIMEOUT):
        """Create the middleware.

        Args:
            app: The application whose cookie parser and session backend
                are used.
            secret_key: Key used to sign and verify session ids. Anyone who
                knows it can forge valid session cookies.
            max_age: Maximum accepted age, in seconds, of a signed session id;
                also passed as the timeout when saving to the backend.
        """
        self.app = app  # Store the App instance
        self.serializer = URLSafeTimedSerializer(secret_key)
        self.max_age = max_age

    def _verified_session_id(self, request: Request) -> Optional[str]:
        """Return the unsigned session id from the request cookie, or None.

        None means the cookie is absent, its signature is invalid, or it is
        older than ``self.max_age``.
        """
        cookies = self.app._parse_cookies(request.headers.get("cookie", ""))
        signed_id = cookies.get("session_id", "")
        if not signed_id:
            return None
        try:
            return self.serializer.loads(signed_id, max_age=self.max_age)
        except BadSignature:  # SignatureExpired is a subclass of BadSignature
            return None

    async def before_request(self, request: Request) -> Optional[Dict]:
        """Verify the session_id cookie and load ``request.session``.

        An invalid, expired or missing cookie yields an empty session.

        Returns:
            Always None, so request processing continues.
        """
        session_id = self._verified_session_id(request)
        if session_id is None:
            request.session = {}
        else:
            request.session = await self.app.session_backend.load(session_id) or {}
        return None

    async def after_request(
        self,
        request: Request,
        status_code: int,
        response_body: Any,
        extra_headers: List[Tuple[str, str]],
    ) -> Optional[Dict]:
        """Persist ``request.session`` and issue a signed cookie when needed.

        Whenever the request did not carry a valid signed id (no cookie,
        tampered cookie, or expired cookie) and the session is non-empty, a
        fresh id is generated and a ``Set-Cookie`` header is appended. The
        client then replaces its bad cookie instead of resending it forever.
        For an existing session, the backend is written only when the data
        changed.

        Args:
            request: The current request; ``request.session`` holds the data
                left by the handler.
            status_code: Response status code (unused).
            response_body: Response body (unused).
            extra_headers: Response header list; mutated in place.

        Returns:
            Always None.
        """
        backend = self.app.session_backend
        session_id = self._verified_session_id(request)

        if not request.session:
            # The handler emptied an existing session (e.g. logout). Persist
            # that; otherwise the stale data would be reloaded next request.
            if session_id is not None and await backend.load(session_id):
                await backend.save(session_id, {}, self.max_age)
            return None

        if session_id is None:
            # No usable cookie: start a new session and hand the client a
            # freshly signed id, replacing any invalid/expired cookie it sent.
            session_id = str(uuid.uuid4())
            signed_session_id = self.serializer.dumps(session_id)
            await backend.save(session_id, request.session, self.max_age)
            extra_headers.append(
                (
                    "Set-Cookie",
                    f"session_id={signed_session_id}; Path=/; SameSite=Lax; HttpOnly; Secure;",
                )
            )
            return None

        current_session = await backend.load(session_id) or {}
        if current_session != request.session:
            await backend.save(session_id, request.session, self.max_age)
        return None


class Root(App):
    """Demo application that counts visits per session."""

    async def index(self):
        """Bump the session's visit counter and report the new total."""
        session = self.request.session
        count = session.get("visits", 0) + 1
        session["visits"] = count
        return f"You have visited {escape(str(count))} times."


app = Root()
# Take the signing key from the environment so deployments don't ship a
# public, hard-coded secret (anyone who knows it can forge session cookies).
# The literal fallback only exists for local development.
app.middlewares.append(
    SignedSessionMiddleware(
        app=app,
        secret_key=os.environ.get("SESSION_SECRET_KEY", "my-secret-key"),
    )
)
