from typing import Optional, Dict, List, Tuple, Any
from html import escape
import uuid
from itsdangerous import URLSafeTimedSerializer, BadSignature
from micropie import App, HttpMiddleware, Request
class CSRFMiddleware(HttpMiddleware):
"""Middleware for CSRF protection using itsdangerous-signed tokens."""
def __init__(self, app: App, secret_key: str, max_age: int = 8 * 3600):
self.app = app # Store the App instance
self.serializer = URLSafeTimedSerializer(secret_key, salt="csrf-token")
self.max_age = max_age
async def before_request(self, request: Request) -> Optional[Dict]:
"""Verify CSRF token for POST/PUT/PATCH requests and generate a new token if needed."""
# Extract session ID from cookies or generate a new one
session_id = (
request.headers.get("cookie", "").split("session_id=")[-1].split(";")[0]
if "session_id=" in request.headers.get("cookie", "")
else str(uuid.uuid4())
)
if request.method in ("POST", "PUT", "PATCH"):
print(f"Request body_params: {request.body_params}") # Debugging
submitted_token = request.form("csrf_token", "")
print(f"Submitted CSRF token: {submitted_token}") # Debugging
if not submitted_token:
return {"status_code": 403, "body": "Missing CSRF token"}
try:
# Verify the submitted token's signature
self.serializer.loads(submitted_token, max_age=self.max_age)
except BadSignature:
return {"status_code": 403, "body": "Invalid or expired CSRF token"}
# Generate a new CSRF token if one doesn't exist in the session
if "csrf_token" not in request.session:
csrf_token = str(uuid.uuid4())
signed_token = self.serializer.dumps(csrf_token)
request.session["csrf_token"] = signed_token
# Save the session
print(f"Saving session with CSRF token: {signed_token}") # Debugging
await self.app.session_backend.save(
session_id, request.session, self.max_age
)
return None
async def after_request(
self,
request: Request,
status_code: int,
response_body: Any,
extra_headers: List[Tuple[str, str]],
) -> Optional[Dict]:
"""Include CSRF token in response headers for client-side use."""
if request.session.get("csrf_token"):
extra_headers.append(("X-CSRF-Token", request.session["csrf_token"]))
return None
# Define the Sub-App
class ApiApp(App):
async def index(self):
return {"message": "Welcome to the API!"}
async def users(self, user_id: str):
return {"user_id": user_id, "message": f"User {user_id} from API"}
async def visits(self):
if "visits" not in self.request.session:
self.request.session["visits"] = 1
else:
self.request.session["visits"] += 1
visits = self.request.session["visits"]
return f"You have visited {visits} times."
async def login(self):
csrf_token = self.request.session.get("csrf_token", "")
return f"""<form method="POST" action="/api/plogin">
<input type="hidden" name="csrf_token" value="{escape(csrf_token)}">
<input type="text" name="name">
<button type="submit">Submit</button>
</form>"""
async def plogin(self, name):
return f"Hello {name}"
class UserApp(App):
async def index(self):
return {"msg": "Hello world"}
async def hello(self, name="world"):
return f"hello {name}"
# Define a Middleware to Mount the Sub-App
class SubAppMiddleware(HttpMiddleware):
def __init__(self, mount_path: str, subapp: App):
self.mount_path = mount_path.lstrip("/")
self.subapp = subapp
async def before_request(self, request):
path = request.scope["path"].lstrip("/")
if path.startswith(self.mount_path):
# Set the subapp and the remaining path
request._subapp = self.subapp
request._subapp_path = path[len(self.mount_path) :].lstrip("/") or "/"
return None # Continue processing
return None # Not a subapp path, continue with main app
async def after_request(self, request, status_code, response_body, extra_headers):
return None # No changes to response
# Define the Main App
class MainApp(App):
async def index(self):
return {"message": "Welcome to the Main App!"}
async def hello(self, name: str):
return {"message": f"Hello, {name} from Main App!"}
# Create and Configure the Apps
app = MainApp()
api_app = ApiApp()
api_app.middlewares.append(CSRFMiddleware(app=app, secret_key="my-secret-key"))
user_app = UserApp()
apiapp_middleware = SubAppMiddleware(mount_path="/api", subapp=api_app)
userapp_middleware = SubAppMiddleware(mount_path="/user", subapp=user_app)
app.middlewares.append(apiapp_middleware)
app.middlewares.append(userapp_middleware)