FastAPI Middleware
Overview
Middleware is an important mechanism for handling cross-cutting concerns in web applications, such as logging, performance monitoring, security checks, and CORS processing. FastAPI is built on Starlette and inherits its powerful middleware system. This chapter explains in detail how to use and create middleware.
🔧 Middleware Basic Concepts
Middleware Execution Flow
graph TB
A[Client Request] --> B[Middleware 1 - Request Processing]
B --> C[Middleware 2 - Request Processing]
C --> D[Route Handler Function]
D --> E[Middleware 2 - Response Processing]
E --> F[Middleware 1 - Response Processing]
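F --> G[Client Response]

Middleware follows the "onion model": each middleware wraps the layers inside it, so request processing runs from the outermost layer inward and response processing runs back out in the reverse order. In FastAPI/Starlette the middleware added last becomes the outermost layer, so it sees the request first and the response last.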
Basic Middleware Example
import asyncio
import json
import logging
import os
import time

from fastapi import FastAPI, Request, Response
# Application instance that the middleware and routes below are registered on
app = FastAPI()
# Configure logging: INFO threshold on the root logger, plus a module-level logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each HTTP request and report how long it took.

    Args:
        request: The incoming request.
        call_next: Awaitable that forwards the request to the next
            middleware or the route handler and returns its response.

    Returns:
        The downstream response with an ``X-Process-Time`` header (seconds)
        added.
    """
    # perf_counter is monotonic, so the measured duration cannot jump or go
    # negative if the system clock is adjusted while the request is running.
    start_time = time.perf_counter()

    # Request processing. %-style arguments defer string formatting until the
    # record is actually emitted.
    logger.info("Request start: %s %s", request.method, request.url)

    # Call next middleware or route handler
    response = await call_next(request)

    # Response processing
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.info(
        "Request complete: %s %s - %s - %.4fs",
        request.method,
        request.url,
        response.status_code,
        process_time,
    )
    return response
@app.get("/")
async def read_root():
    """Return a fixed greeting; a quick way to watch the middleware logs."""
    greeting = {"message": "Hello World"}
    return greeting
@app.get("/slow")
async def slow_endpoint():
await asyncio.sleep(2) # Simulate slow operation
return {"message": "This was slow"}🛡️ Built-in Middleware
CORS Middleware
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    # Allowed origins
    allow_origins=["http://localhost:3000", "https://myapp.com"],
    # Allow credentials
    allow_credentials=True,
    # Allowed methods
    allow_methods=["GET", "POST", "PUT", "DELETE"],
    # Allowed headers
    allow_headers=["*"],
    # Exposed headers
    expose_headers=["X-Custom-Header"],
    # Preflight request cache time
    max_age=3600,
)
# Development environment configuration (allow all origins)
if os.getenv("ENVIRONMENT") == "development":
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)HTTPS Redirect Middleware
from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware
# Production environment force HTTPS
if os.getenv("ENVIRONMENT") == "production":
app.add_middleware(HTTPSRedirectMiddleware)Trusted Host Middleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=["example.com", "*.example.com", "localhost", "127.0.0.1"]
)GZip Compression Middleware
from fastapi.middleware.gzip import GZipMiddleware
# Compress responses; bodies smaller than minimum_size bytes are sent as-is
app.add_middleware(GZipMiddleware, minimum_size=1000)

🔨 Custom Middleware
Request ID Middleware
import uuid
from contextvars import ContextVar
# Use ContextVar to store request ID
request_id_contextvar: ContextVar[str] = ContextVar('request_id', default="")


class RequestIDMiddleware:
    """Pure ASGI middleware that tags every HTTP request with a UUID4.

    While the request is being handled the ID is available through
    ``request_id_contextvar``; it is also appended to the response as an
    ``x-request-id`` header. Non-HTTP scopes are passed through unchanged.
    """

    def __init__(self, app):
        # Downstream ASGI application
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Generate request ID
        request_id = str(uuid.uuid4())
        token = request_id_contextvar.set(request_id)
        request_id_header = [b"x-request-id", request_id.encode()]

        # Add to response headers
        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append(request_id_header)
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Restore the previous value so the ID cannot leak into whatever
            # runs next in the same context (e.g. when the middleware is
            # invoked directly, as in tests).
            request_id_contextvar.reset(token)
# Register the request-ID middleware on the application
app.add_middleware(RequestIDMiddleware)
def get_request_id() -> str:
    """Return the ID stored for the current request.

    Falls back to the context variable's default (an empty string) when no
    ID has been set in the current context.
    """
    current_id = request_id_contextvar.get()
    return current_id
@app.get("/test")
async def test_endpoint():
return {"request_id": get_request_id(), "message": "Test endpoint"}Performance Monitoring Middleware
import psutil
from typing import Dict, Any
class PerformanceMiddleware:
    """Pure ASGI middleware that collects simple request statistics.

    For every HTTP request it records the elapsed time, the request method
    and the response status in ``self.stats``, and adds CPU, memory and
    request-count headers to the response. Non-HTTP scopes are passed
    through unchanged.
    """

    def __init__(self, app):
        # Downstream ASGI application
        self.app = app
        self.stats = {
            "total_requests": 0,
            "total_time": 0,
            "avg_response_time": 0,
            "request_count_by_method": {},
            "request_count_by_status": {},
        }

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # perf_counter is monotonic, so durations are immune to clock changes
        start_time = time.perf_counter()
        request_method = scope["method"]
        status_code = None

        async def send_wrapper(message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
                # Add performance headers
                headers = list(message.get("headers", []))
                # System resource usage
                cpu_percent = psutil.cpu_percent()
                memory_percent = psutil.virtual_memory().percent
                headers.extend([
                    [b"x-cpu-usage", f"{cpu_percent:.1f}%".encode()],
                    [b"x-memory-usage", f"{memory_percent:.1f}%".encode()],
                    [b"x-request-count", str(self.stats["total_requests"]).encode()],
                ])
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Record the request even when the downstream app raised, so
            # failing requests are not missing from the statistics.
            self._record(request_method, status_code, time.perf_counter() - start_time)

    def _record(self, request_method, status_code, response_time):
        """Fold one finished request into the running statistics."""
        stats = self.stats
        stats["total_requests"] += 1
        stats["total_time"] += response_time
        stats["avg_response_time"] = stats["total_time"] / stats["total_requests"]

        # Statistics by method
        by_method = stats["request_count_by_method"]
        by_method[request_method] = by_method.get(request_method, 0) + 1

        # Statistics by status code (None when no response was started)
        if status_code is not None:
            by_status = stats["request_count_by_status"]
            by_status[status_code] = by_status.get(status_code, 0) + 1
performance_middleware = PerformanceMiddleware(app)
app.add_middleware(lambda: performance_middleware)
@app.get("/stats")
async def get_performance_stats():
return performance_middleware.statsRate Limiting Middleware
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
class RateLimitMiddleware:
    """Sliding-window, per-client rate limiter as pure ASGI middleware.

    Each client may make at most ``calls`` requests within any ``period``
    seconds. Requests over the limit get a 429 JSON response; allowed
    requests get ``x-ratelimit-*`` headers added to their response.

    Args:
        app: Downstream ASGI application.
        calls: Allowed number of calls per window.
        period: Window length in seconds.
        trust_proxy_headers: When True (the default, matching the original
            behaviour) the client IP is taken from proxy headers such as
            ``x-forwarded-for``. Clients can forge these headers, so set
            this to False unless a trusted proxy overwrites them.
    """

    # Proxy headers consulted, in order, when trust_proxy_headers is enabled
    _PROXY_HEADERS = (b"x-forwarded-for", b"x-real-ip", b"cf-connecting-ip")

    def __init__(self, app, calls: int = 100, period: int = 60,
                 trust_proxy_headers: bool = True):
        self.app = app
        self.calls = calls  # Allowed number of calls
        self.period = period  # Time window (seconds)
        self.trust_proxy_headers = trust_proxy_headers
        self.clients = defaultdict(list)  # Store client request times
        self.cleanup_task = None  # Not used by this class; kept for compatibility

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Get client IP
        client_ip = self._get_client_ip(scope)
        current_time = datetime.now()

        # Cleanup old requests
        self._cleanup_old_requests(client_ip, current_time)

        # Check rate limit (.get avoids creating an empty bucket)
        recent = self.clients.get(client_ip, [])
        if len(recent) >= self.calls:
            await self._send_rate_limited(send, recent, current_time)
            return

        # Record request time
        self.clients[client_ip].append(current_time)
        recent = self.clients[client_ip]

        # Add rate limit headers
        remaining = self.calls - len(recent)
        rate_headers = [
            [b"x-ratelimit-limit", str(self.calls).encode()],
            [b"x-ratelimit-remaining", str(remaining).encode()],
            [b"x-ratelimit-reset", str(self._reset_epoch(recent, current_time)).encode()],
        ]

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.extend(rate_headers)
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_wrapper)

    def _reset_epoch(self, recent, current_time):
        """Return the Unix time at which the oldest recorded request expires.

        That is the moment the client regains a slot, which is what the
        reset header should report (not simply "now + period").
        """
        oldest = recent[0] if recent else current_time
        return int((oldest + timedelta(seconds=self.period)).timestamp())

    async def _send_rate_limited(self, send, recent, current_time):
        """Send the 429 response for a client that exceeded its limit."""
        reset_epoch = self._reset_epoch(recent, current_time)
        # Always at least one second so clients do not retry immediately
        retry_after = max(1, reset_epoch - int(current_time.timestamp()))
        body = b'{"error": "Rate limit exceeded", "message": "Too many requests"}'
        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": [
                [b"content-type", b"application/json"],
                [b"content-length", str(len(body)).encode()],
                [b"retry-after", str(retry_after).encode()],
                [b"x-ratelimit-limit", str(self.calls).encode()],
                [b"x-ratelimit-remaining", b"0"],
                [b"x-ratelimit-reset", str(reset_epoch).encode()],
            ],
        })
        await send({"type": "http.response.body", "body": body})

    def _get_client_ip(self, scope):
        """Return the client address used as the rate-limit key."""
        if self.trust_proxy_headers:
            # Try to get real IP from headers
            headers = dict(scope.get("headers", []))
            # Check common proxy headers
            for header in self._PROXY_HEADERS:
                if header in headers:
                    ip = headers[header].decode().split(",")[0].strip()
                    if ip:
                        return ip

        # Use directly connected IP
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _cleanup_old_requests(self, client_ip, current_time):
        """Drop this client's timestamps that fell out of the window."""
        cutoff_time = current_time - timedelta(seconds=self.period)
        recent = [
            req_time for req_time in self.clients.get(client_ip, ())
            if req_time > cutoff_time
        ]
        if recent:
            self.clients[client_ip] = recent
        else:
            # If client has no active requests, delete record
            self.clients.pop(client_ip, None)
# Apply rate limiting middleware: maximum 100 requests per minute
app.add_middleware(lambda: RateLimitMiddleware(app, calls=100, period=60))Authentication Middleware
import jwt
from fastapi import HTTPException, status
class AuthenticationMiddleware:
    """Pure ASGI middleware that requires a valid HS256 JWT bearer token.

    Requests to excluded paths and non-HTTP scopes are passed through.
    Every other request must carry ``Authorization: Bearer <token>``; the
    decoded payload is stored in ``scope["user"]``. Failures get a 401 JSON
    response.

    Args:
        app: Downstream ASGI application.
        secret_key: Key used to verify token signatures.
        excluded_paths: Paths that skip authentication. An entry matches the
            exact path and anything below it (``/docs`` matches
            ``/docs/...``); ``/`` matches only the root path itself.
    """

    def __init__(self, app, secret_key: str, excluded_paths: list = None):
        self.app = app
        self.secret_key = secret_key
        self.excluded_paths = excluded_paths or ["/", "/docs", "/redoc", "/openapi.json", "/login"]

    async def __call__(self, scope, receive, send):
        # Non-HTTP traffic and excluded paths skip authentication
        if scope["type"] != "http" or self._is_excluded_path(scope["path"]):
            await self.app(scope, receive, send)
            return

        # Get Authorization header
        headers = dict(scope.get("headers", []))
        auth_header = headers.get(b"authorization")
        if not auth_header:
            await self._send_unauthorized(send, "Missing authorization header")
            return

        # latin-1 decodes any byte sequence, so a malformed header ends up as
        # an invalid-format/invalid-token 401 instead of an unhandled error.
        auth_value = auth_header.decode("latin-1")
        if not auth_value.startswith("Bearer "):
            await self._send_unauthorized(send, "Invalid authorization format")
            return

        token = auth_value[len("Bearer "):]  # Remove "Bearer " prefix
        try:
            # Validate JWT token
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            await self._send_unauthorized(send, "Token has expired")
            return
        except jwt.InvalidTokenError:
            await self._send_unauthorized(send, "Invalid token")
            return

        # Add user info to scope
        scope["user"] = payload
        await self.app(scope, receive, send)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` is an excluded path or lies below one.

        A bare ``startswith`` check would let the ``/`` entry match every
        path and switch authentication off entirely, so prefixes are only
        matched on a path-segment boundary.
        """
        for excluded in self.excluded_paths:
            if path == excluded:
                return True
            prefix = excluded.rstrip("/")
            if prefix and path.startswith(prefix + "/"):
                return True
        return False

    async def _send_unauthorized(self, send, message: str):
        """Send a 401 JSON response carrying ``message``."""
        # json.dumps escapes the message, keeping the body valid JSON
        body = json.dumps({"error": "Unauthorized", "message": message}).encode()
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                [b"content-type", b"application/json"],
                [b"www-authenticate", b"Bearer"],
            ],
        })
        await send({"type": "http.response.body", "body": body})
# Apply authentication middleware
# NOTE: placeholder key for this example only; a real deployment should load
# the signing key from configuration or a secret store, never from source.
SECRET_KEY = "your-secret-key"
# Pass the class and its options: add_middleware() constructs it with the
# downstream app as the first argument, so a zero-argument lambda would fail.
app.add_middleware(AuthenticationMiddleware, secret_key=SECRET_KEY)
# Login endpoint (no authentication required)
@app.post("/login")
async def login(username: str, password: str):
    """Issue a 24-hour HS256 access token for the demo account.

    Args:
        username: Account name supplied by the caller.
        password: Password supplied by the caller.

    Returns:
        A dict with ``access_token`` and ``token_type``.

    Raises:
        HTTPException: 401 when the credentials do not match.
    """
    import secrets
    from datetime import timezone

    # Simple user authentication. compare_digest takes the same time whether
    # or not the strings match; both checks always run so neither result
    # leaks through timing.
    username_ok = secrets.compare_digest(username.encode(), b"admin")
    password_ok = secrets.compare_digest(password.encode(), b"secret")
    if not (username_ok and password_ok):
        raise HTTPException(status_code=401, detail="Invalid credentials")

    payload = {
        "user_id": 1,
        "username": username,
        # Timezone-aware expiry; datetime.utcnow() is deprecated and returns
        # a naive value that is easily misread as local time.
        "exp": datetime.now(timezone.utc) + timedelta(hours=24),
    }
    token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
    return {"access_token": token, "token_type": "bearer"}
# Protected endpoint
@app.get("/protected")
async def protected_endpoint(request: Request):
user = request.scope.get("user")
return {"message": "This is protected", "user": user}📊 Error Handling Middleware
Global Exception Handling Middleware
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse
class ErrorHandlingMiddleware:
    """Pure ASGI middleware that turns unhandled exceptions into JSON errors.

    Exceptions raised by the downstream app before a response has started
    are logged and converted to a JSON error response. Once a response has
    started the exception is re-raised, because a second response cannot be
    sent on the same connection. Non-HTTP scopes are passed through.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message):
            # Remember whether the downstream app already sent its headers
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as exc:
            if response_started:
                # Headers are already on the wire; let the server deal with it
                raise
            request = Request(scope, receive)
            await self._handle_error(exc, request, send)

    async def _handle_error(self, exc: Exception, request: Request, send):
        """Log ``exc`` and send the matching JSON error response."""
        # Log error (the exception is passed explicitly so the traceback does
        # not depend on the caller's active exception state)
        logger.error(f"Unhandled exception: {exc}", exc_info=exc)

        # Determine response based on exception type
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            detail = exc.detail
        elif isinstance(exc, ValueError):
            status_code = 400
            detail = "Invalid input data"
        elif isinstance(exc, FileNotFoundError):
            status_code = 404
            detail = "Resource not found"
        else:
            status_code = 500
            detail = "Internal server error"

        # Build error response
        error_response = {
            "error": {
                "type": exc.__class__.__name__,
                "message": detail,
                "path": str(request.url.path),
                "method": request.method,
                "timestamp": datetime.now().isoformat(),
            }
        }

        # Development environment includes detailed error information
        if os.getenv("ENVIRONMENT") == "development":
            error_response["error"]["traceback"] = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )

        response = JSONResponse(content=error_response, status_code=status_code)
        # scope/receive are not locals of this method; take them from the request
        await response(request.scope, request.receive, send)
# Register the global error-handling middleware on the application
app.add_middleware(ErrorHandlingMiddleware)

🔧 Middleware Best Practices
Middleware Factory
def create_logging_middleware(log_level: str = "INFO"):
    """Create logging middleware factory function.

    Args:
        log_level: Name of a ``logging`` level constant, e.g. ``"INFO"``.

    Returns:
        A callable that takes the downstream ASGI app and returns an ASGI
        middleware logging each HTTP request and response at ``log_level``.

    Raises:
        AttributeError: If ``log_level`` is not an attribute of ``logging``.
    """
    # Resolve the level once, up front: a bad name fails at creation time
    # instead of on the first request, and is not looked up per message.
    level = getattr(logging, log_level)

    def logging_middleware(app):
        async def middleware(scope, receive, send):
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            # perf_counter is monotonic, so durations are immune to clock changes
            start_time = time.perf_counter()
            request = Request(scope, receive)
            logger.log(level, "Request: %s %s", request.method, request.url)

            async def send_wrapper(message):
                if message["type"] == "http.response.start":
                    process_time = time.perf_counter() - start_time
                    logger.log(
                        level,
                        "Response: %s - %.4fs",
                        message["status"],
                        process_time,
                    )
                await send(message)

            await app(scope, receive, send_wrapper)

        return middleware

    return logging_middleware
# Use factory to create middleware: the returned callable receives the
# downstream app when it is invoked and returns the ASGI middleware
app.add_middleware(create_logging_middleware("DEBUG"))

Conditional Middleware
def conditional_middleware(condition_func):
    """Conditional middleware decorator.

    Args:
        condition_func: Zero-argument callable, evaluated each time the
            middleware is constructed; a truthy result enables it.

    Returns:
        A class decorator. The decorated name becomes a factory that builds
        the middleware when the condition holds and otherwise returns the
        downstream app unchanged.
    """
    def decorator(middleware_class):
        def wrapper(app, *args, **kwargs):
            if condition_func():
                # Forward any options given to add_middleware()
                return middleware_class(app, *args, **kwargs)
            # Condition not met: hand back the downstream app itself, which
            # behaves like a pass-through without an extra call layer.
            return app

        # Keep the original name and docstring for debugging and repr output
        wrapper.__name__ = getattr(middleware_class, "__name__", wrapper.__name__)
        wrapper.__doc__ = middleware_class.__doc__
        return wrapper
    return decorator
@conditional_middleware(lambda: os.getenv("ENABLE_PROFILING") == "true")
class ProfilingMiddleware:
    """Pure ASGI middleware that profiles each HTTP request with cProfile.

    One ``.prof`` file per request is written under ``/tmp``. Only HTTP
    scopes are profiled: a lifespan scope stays open for the whole life of
    the application, so profiling it would keep a profiler active for every
    request. Requests handled concurrently on the same thread may still
    conflict with each other's profiler.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Performance analysis logic
        import cProfile

        profiler = cProfile.Profile()
        profiler.enable()
        try:
            await self.app(scope, receive, send)
        finally:
            # Always stop the profiler, even when the downstream app raises;
            # otherwise it stays enabled and the data for this request is lost.
            profiler.disable()
            # Save or output analysis results
            profiler.dump_stats(f"/tmp/profile_{time.time()}.prof")
# Register the profiling middleware; the decorator above decides at
# construction time whether it is actually enabled
app.add_middleware(ProfilingMiddleware)

Middleware Configuration
from pydantic import BaseSettings
class MiddlewareSettings(BaseSettings):
    """Switches and parameters that control which middleware gets installed."""

    # CORS
    enable_cors: bool = True
    cors_origins: list = ["*"]

    # GZip compression
    enable_gzip: bool = True
    gzip_minimum_size: int = 1000

    # Rate limiting
    enable_rate_limiting: bool = False
    rate_limit_calls: int = 100
    rate_limit_period: int = 60

    class Config:
        # File named here is passed to BaseSettings as its env file
        env_file = ".env"


# Single settings instance used by setup_middleware()
settings = MiddlewareSettings()
def setup_middleware(app: FastAPI):
    """Configure all middleware.

    Installs the optional middleware enabled in ``settings`` and then the
    middleware that is always present. Middleware added later wraps the
    middleware added earlier, so the error handler registered last is the
    outermost layer.

    Args:
        app: Application to register the middleware on.
    """
    # CORS
    if settings.enable_cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins,
            # Credentials cannot be combined with a wildcard origin, so they
            # are only enabled when every origin is listed explicitly.
            allow_credentials="*" not in settings.cors_origins,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # GZip compression
    if settings.enable_gzip:
        app.add_middleware(
            GZipMiddleware,
            minimum_size=settings.gzip_minimum_size,
        )

    # Rate limiting
    if settings.enable_rate_limiting:
        # Pass the class and its options: add_middleware() constructs it with
        # the downstream app, so a zero-argument lambda would fail.
        app.add_middleware(
            RateLimitMiddleware,
            calls=settings.rate_limit_calls,
            period=settings.rate_limit_period,
        )

    # Always add middleware
    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)
# Setup middleware: install everything configured above on the application
setup_middleware(app)

Summary
This chapter detailed FastAPI's middleware system:
- ✅ Middleware Basics: Execution flow, basic concepts
- ✅ Built-in Middleware: CORS, HTTPS redirect, GZip, etc.
- ✅ Custom Middleware: Request ID, performance monitoring, rate limiting, authentication
- ✅ Error Handling: Global exception handling middleware
- ✅ Best Practices: Middleware factories, conditional middleware, configuration management
Middleware is a powerful tool for handling cross-cutting concerns. Used judiciously, it can greatly improve an application's functionality and maintainability.
Middleware Design Recommendations
- Give each middleware a single responsibility
- Consider middleware execution order
- Provide configuration switches to enable/disable middleware
- Pay attention to performance impact, avoid blocking operations
- Handle errors and log them thoroughly
- Write testable middleware code
In the next chapter, we will learn FastAPI's dependency injection system and understand how to manage application dependencies.