#FastAPI Middleware
#Overview
Middleware is an important mechanism for handling cross-cutting concerns in web applications, such as logging, performance monitoring, security checks, and CORS processing. FastAPI, which is built on Starlette, provides a powerful middleware system. This chapter explains in detail how to use existing middleware and how to create your own.
#🔧 Middleware Basic Concepts
#Middleware Execution Flow
graph TB
A[Client Request] --> B[Middleware 1 - Request Processing]
B --> C[Middleware 2 - Request Processing]
C --> D[Route Handler Function]
D --> E[Middleware 2 - Response Processing]
E --> F[Middleware 1 - Response Processing]
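F --> G[Client Response]
Middleware follows an "onion model": each middleware wraps the next one, so a request passes inward through the layers to the route handler, and the response passes back outward through the same layers in reverse order. In FastAPI (Starlette), the middleware added *last* is the *outermost* layer: it sees the request first and the response last.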
#Basic Middleware Example
import asyncio
import functools
import json
import logging
import os
import secrets
import time
from datetime import timezone

from fastapi import FastAPI, Request, Response
app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every HTTP request and report how long it took.

    The elapsed time in seconds is written to the ``X-Process-Time``
    response header and logged together with the response status code.
    """
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock and can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    # Request processing
    logger.info("Request start: %s %s", request.method, request.url)
    # Call the next middleware or the route handler
    response = await call_next(request)
    # Response processing
    process_time = time.perf_counter() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.info(
        "Request complete: %s %s - %s - %.4fs",
        request.method,
        request.url,
        response.status_code,
        process_time,
    )
    return response
@app.get("/")
async def read_root():
    """Return a static greeting."""
    return {"message": "Hello World"}


@app.get("/slow")
async def slow_endpoint():
    """Respond after a two-second non-blocking pause (simulated slow work).

    Relies on ``asyncio`` being imported at the top of this example.
    """
    await asyncio.sleep(2)  # Simulate slow operation without blocking the event loop
    return {"message": "This was slow"}
#🛡️ Built-in Middleware
#CORS Middleware
import os

from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Register exactly ONE CORSMiddleware. A second instance added as a
# "development override" does not replace the first: both stay in the
# stack and each rewrites CORS headers, which makes behavior hard to reason about.
if os.getenv("ENVIRONMENT") == "development":
    # Development: allow every origin. Allowing all origins *with*
    # credentials lets any website send authenticated requests on a user's
    # behalf, so this branch must never be active in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3000", "https://myapp.com"],  # Allowed origins
        allow_credentials=True,  # Allow credentials (cookies, Authorization)
        allow_methods=["GET", "POST", "PUT", "DELETE"],  # Allowed methods
        allow_headers=["*"],  # Allowed request headers
        expose_headers=["X-Custom-Header"],  # Response headers readable by browser JS
        max_age=3600,  # Preflight response cache time (seconds)
    )
#HTTPS Redirect Middleware
import os

from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

# Production environment: redirect plain-HTTP requests to HTTPS.
# NOTE: behind a TLS-terminating proxy, make sure the server honors the
# proxy's forwarded-protocol headers, otherwise every request looks like
# plain HTTP and gets redirected again.
if os.getenv("ENVIRONMENT") == "production":
    app.add_middleware(HTTPSRedirectMiddleware)
#Trusted Host Middleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware

# Only serve requests whose Host header matches one of these patterns.
# The wildcard entry is meant for subdomains; the apex domain is listed
# as its own entry.
app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=[
        "example.com",
        "*.example.com",
        "localhost",
        "127.0.0.1",
    ],
)
#GZip Compression Middleware
from fastapi.middleware.gzip import GZipMiddleware

# Compress responses with GZip; bodies below `minimum_size` bytes are
# left as they are.
app.add_middleware(
    GZipMiddleware,
    minimum_size=1000,
)
#🔨 Custom Middleware
#Request ID Middleware
import uuid
from contextvars import ContextVar

# ContextVar holding the current request's ID so that code running inside
# the request (route handlers, helpers) can read it without passing it along.
request_id_contextvar: ContextVar[str] = ContextVar('request_id', default="")


class RequestIDMiddleware:
    """Pure ASGI middleware that tags each HTTP request with a UUID4.

    While the downstream app runs, the ID is available through
    ``request_id_contextvar``; it is also returned to the client in an
    ``x-request-id`` response header. Non-HTTP scopes are passed through
    unchanged.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Generate request ID
        request_id = str(uuid.uuid4())
        # Keep the token so the previous value can be restored afterwards;
        # without reset() the ID stays set in this context after the request.
        token = request_id_contextvar.set(request_id)

        async def send_wrapper(message):
            # Add the ID to the response headers
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.append((b"x-request-id", request_id.encode()))
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            request_id_contextvar.reset(token)
app.add_middleware(RequestIDMiddleware)


def get_request_id() -> str:
    """Return the ID stored for the request currently being handled.

    Falls back to the ContextVar default (an empty string) when no ID has
    been set in the current context.
    """
    current_id = request_id_contextvar.get()
    return current_id


@app.get("/test")
async def test_endpoint():
    """Echo the current request ID together with a fixed message."""
    return {
        "request_id": get_request_id(),
        "message": "Test endpoint",
    }
#Performance Monitoring Middleware
import psutil
from typing import Dict, Any
class PerformanceMiddleware:
    """Pure ASGI middleware that collects simple in-process request statistics.

    ``self.stats`` tracks the total request count, total and average response
    time (seconds), and counters keyed by HTTP method and by status code.
    Every HTTP response also gets ``x-cpu-usage``, ``x-memory-usage`` and
    ``x-request-count`` headers.

    NOTE(review): the CPU/memory headers expose host metrics to every client;
    consider dropping them outside internal deployments.
    """

    def __init__(self, app):
        self.app = app
        self.stats = {
            "total_requests": 0,
            "total_time": 0,
            "avg_response_time": 0,
            "request_count_by_method": {},
            "request_count_by_status": {},
        }

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Monotonic clock, so wall-clock adjustments cannot distort timings.
        start_time = time.perf_counter()
        request_method = scope["method"]
        status_code = None

        async def send_wrapper(message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
                # Add performance headers
                headers = list(message.get("headers", []))
                # System resource usage
                cpu_percent = psutil.cpu_percent()
                memory_percent = psutil.virtual_memory().percent
                headers.extend([
                    (b"x-cpu-usage", f"{cpu_percent:.1f}%".encode()),
                    (b"x-memory-usage", f"{memory_percent:.1f}%".encode()),
                    (b"x-request-count", str(self.stats["total_requests"]).encode()),
                ])
                message["headers"] = headers
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        finally:
            # Record the request even when the downstream app raised, so
            # failing requests do not silently vanish from the statistics.
            self._record(request_method, status_code, time.perf_counter() - start_time)

    def _record(self, method, status_code, response_time):
        """Fold one finished request into ``self.stats``."""
        stats = self.stats
        stats["total_requests"] += 1
        stats["total_time"] += response_time
        stats["avg_response_time"] = stats["total_time"] / stats["total_requests"]
        # Statistics by method
        by_method = stats["request_count_by_method"]
        by_method[method] = by_method.get(method, 0) + 1
        # Statistics by status code; stays None when the app failed before
        # starting a response.
        if status_code is not None:
            by_status = stats["request_count_by_status"]
            by_status[status_code] = by_status.get(status_code, 0) + 1
# Starlette instantiates middleware itself, calling the registered factory
# with the next ASGI app in the stack. A zero-argument lambda cannot receive
# that app, and an instance built around the FastAPI `app` itself would call
# back into its own middleware stack. Register a factory instead and keep a
# reference to the instance it creates so /stats can read its statistics.
performance_middleware = None


def _create_performance_middleware(app):
    """Wrap ``app`` (the next ASGI app) in PerformanceMiddleware and remember it.

    The parameter is named ``app`` because Starlette may pass it by keyword.
    """
    global performance_middleware
    performance_middleware = PerformanceMiddleware(app)
    return performance_middleware


app.add_middleware(_create_performance_middleware)


@app.get("/stats")
async def get_performance_stats():
    """Return the collected statistics (empty until the stack has been built)."""
    if performance_middleware is None:
        return {}
    return performance_middleware.stats
#Rate Limiting Middleware
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta
class RateLimitMiddleware:
    """Pure ASGI sliding-window rate limiter keyed by client IP.

    Each client may make at most ``calls`` HTTP requests within any
    ``period``-second window; further requests get a JSON 429 response with a
    ``retry-after`` header. Allowed responses carry ``x-ratelimit-limit``,
    ``x-ratelimit-remaining`` and ``x-ratelimit-reset`` (Unix time at which
    the oldest counted request leaves the window).

    State lives in this instance's memory, so limits apply per process.

    Args:
        app: The next ASGI application in the stack.
        calls: Allowed number of requests per window.
        period: Window length in seconds.
        trust_proxy_headers: When True, take the client IP from
            ``X-Forwarded-For`` / ``X-Real-IP`` / ``CF-Connecting-IP``.
            Enable only behind a reverse proxy that overwrites these headers;
            otherwise any client can send a fresh fake value per request and
            bypass the limit entirely.
    """

    def __init__(self, app, calls: int = 100, period: int = 60,
                 trust_proxy_headers: bool = False):
        self.app = app
        self.calls = calls  # Allowed number of calls
        self.period = period  # Time window (seconds)
        self.trust_proxy_headers = trust_proxy_headers
        # client IP -> time.monotonic() timestamps of requests in the window
        self.clients = defaultdict(list)
        # When all clients were last swept for expired entries.
        self._last_sweep = time.monotonic()

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        client_ip = self._get_client_ip(scope)
        # Monotonic clock for window arithmetic: wall-clock jumps (NTP sync,
        # DST changes on naive local datetimes) would shrink or stretch it.
        now = time.monotonic()

        self._sweep_idle_clients(now)
        self._cleanup_old_requests(client_ip, now)
        timestamps = self.clients[client_ip]

        # Check rate limit
        if len(timestamps) >= self.calls:
            # Seconds until the oldest counted request leaves the window.
            wait = timestamps[0] + self.period - now if timestamps else self.period
            await self._send_rate_limited(send, wait)
            return

        # Record request time
        timestamps.append(now)
        remaining = self.calls - len(timestamps)
        reset_header = self._reset_timestamp(timestamps[0] + self.period - now)

        async def send_wrapper(message):
            # Add rate limit headers
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.extend([
                    (b"x-ratelimit-limit", str(self.calls).encode()),
                    (b"x-ratelimit-remaining", str(remaining).encode()),
                    (b"x-ratelimit-reset", reset_header),
                ])
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_wrapper)

    async def _send_rate_limited(self, send, wait):
        """Send a complete 429 JSON response; ``wait`` is seconds until reset."""
        body = b'{"error": "Rate limit exceeded", "message": "Too many requests"}'
        # Whole seconds, rounded up so an immediate retry is not rejected again.
        retry_after = int(wait) + (1 if wait % 1 else 0)
        await send({
            "type": "http.response.start",
            "status": 429,
            "headers": [
                (b"content-type", b"application/json"),
                (b"content-length", str(len(body)).encode()),
                (b"retry-after", str(retry_after).encode()),
                (b"x-ratelimit-limit", str(self.calls).encode()),
                (b"x-ratelimit-remaining", b"0"),
                (b"x-ratelimit-reset", self._reset_timestamp(wait)),
            ],
        })
        await send({"type": "http.response.body", "body": body})

    @staticmethod
    def _reset_timestamp(seconds_from_now):
        """Encode "now + seconds_from_now" as a Unix-time header value."""
        return str(int(time.time() + seconds_from_now)).encode()

    def _get_client_ip(self, scope):
        """Return the client address used as the rate-limit key."""
        if self.trust_proxy_headers:
            headers = dict(scope.get("headers", []))
            # Check common proxy headers; X-Forwarded-For may carry a chain
            # whose first entry is the originating client.
            for header in (b"x-forwarded-for", b"x-real-ip", b"cf-connecting-ip"):
                value = headers.get(header)
                if value:
                    # latin-1 never fails to decode, unlike UTF-8.
                    ip = value.decode("latin-1").split(",")[0].strip()
                    if ip:
                        return ip
        # Use the directly connected peer address
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _cleanup_old_requests(self, client_ip, current_time):
        """Drop this client's timestamps that have fallen out of the window."""
        cutoff_time = current_time - self.period
        recent = [t for t in self.clients.get(client_ip, ()) if t > cutoff_time]
        if recent:
            self.clients[client_ip] = recent
        else:
            # Forget clients with no requests left in the window
            self.clients.pop(client_ip, None)

    def _sweep_idle_clients(self, now):
        """At most once per period, forget every client whose newest request expired.

        Per-request cleanup only touches the calling client, so without this
        sweep the entries of IPs that never return would accumulate forever.
        """
        if now - self._last_sweep < self.period:
            return
        self._last_sweep = now
        cutoff = now - self.period
        idle = [ip for ip, stamps in self.clients.items() if not stamps or stamps[-1] <= cutoff]
        for ip in idle:
            del self.clients[ip]
# Apply rate limiting middleware: maximum 100 requests per minute.
# Pass the class plus keyword options; Starlette constructs it around the
# next app in the stack. A zero-argument lambda cannot receive that app, and
# wrapping the FastAPI `app` itself would make the middleware call back into
# its own stack.
app.add_middleware(RateLimitMiddleware, calls=100, period=60)
#Authentication Middleware
import jwt
from fastapi import HTTPException, status
class AuthenticationMiddleware:
    """Pure ASGI middleware requiring a valid HS256 JWT bearer token.

    Requests to excluded paths pass through unauthenticated. Every other HTTP
    request must send ``Authorization: Bearer <token>``. On success the
    decoded JWT payload is stored in ``scope["user"]``; otherwise a JSON 401
    response is sent and the downstream app is not called.

    Args:
        app: The next ASGI application.
        secret_key: HMAC key used to verify token signatures.
        excluded_paths: Public paths. ``"/"`` matches only the root path; any
            other entry matches itself and everything below it (``"/docs"``
            covers ``/docs/oauth2-redirect`` but not ``/docsx``).
    """

    def __init__(self, app, secret_key: str, excluded_paths: list = None):
        self.app = app
        self.secret_key = secret_key
        self.excluded_paths = excluded_paths or ["/", "/docs", "/redoc", "/openapi.json", "/login"]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Check if it's an excluded path
        if self._is_excluded_path(scope["path"]):
            await self.app(scope, receive, send)
            return

        # Get Authorization header
        headers = dict(scope.get("headers", []))
        auth_header = headers.get(b"authorization")
        if not auth_header:
            await self._send_unauthorized(send, "Missing authorization header")
            return

        # The auth scheme name is case-insensitive (RFC 7235); latin-1 decoding
        # cannot fail, so malformed header bytes cannot crash the middleware.
        scheme, _, token = auth_header.decode("latin-1").partition(" ")
        token = token.strip()
        if scheme.lower() != "bearer" or not token:
            await self._send_unauthorized(send, "Invalid authorization format")
            return

        try:
            # Validate JWT token
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
        except jwt.ExpiredSignatureError:
            await self._send_unauthorized(send, "Token has expired")
            return
        except jwt.InvalidTokenError:
            await self._send_unauthorized(send, "Invalid token")
            return

        # Add user info to scope
        scope["user"] = payload
        await self.app(scope, receive, send)

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` is public according to ``excluded_paths``."""
        for excluded in self.excluded_paths:
            prefix = excluded.rstrip("/")
            if not prefix:
                # "/" must match exactly: as a prefix it would match every
                # path and silently disable authentication for the whole app.
                if path == "/":
                    return True
            elif path == prefix or path.startswith(prefix + "/"):
                # Match on path-segment boundaries so "/docs" != "/docsx".
                return True
        return False

    async def _send_unauthorized(self, send, message: str):
        """Send a complete 401 JSON response carrying ``message``."""
        body = json.dumps({"error": "Unauthorized", "message": message}).encode()
        await send({
            "type": "http.response.start",
            "status": 401,
            "headers": [
                (b"content-type", b"application/json"),
                (b"content-length", str(len(body)).encode()),
                # RFC 6750: advertise the expected authentication scheme.
                (b"www-authenticate", b"Bearer"),
            ],
        })
        await send({"type": "http.response.body", "body": body})
# Apply authentication middleware.
# Read the signing key from the environment; the literal fallback is only
# for local experiments and must never be used in production.
SECRET_KEY = os.getenv("JWT_SECRET_KEY", "your-secret-key")
# Register the class plus its options; Starlette constructs it around the
# next app in the stack (a zero-argument lambda cannot accept that app).
app.add_middleware(AuthenticationMiddleware, secret_key=SECRET_KEY)
# Login endpoint (no authentication required)
@app.post("/login")
async def login(username: str, password: str):
    """Issue a 24-hour HS256 JWT for the hard-coded demo user.

    Raises:
        HTTPException: 401 when the credentials do not match.

    NOTE(review): plain ``str`` parameters are read by FastAPI from the query
    string, so the password ends up in URLs and access logs. Real code should
    accept credentials in the request body (Pydantic model or form fields).
    """
    # Constant-time comparison of bytes (compare_digest rejects non-ASCII str).
    valid_user = secrets.compare_digest(username.encode(), b"admin")
    valid_password = secrets.compare_digest(password.encode(), b"secret")
    if not (valid_user and valid_password):
        raise HTTPException(status_code=401, detail="Invalid credentials")
    payload = {
        "user_id": 1,
        "username": username,
        # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
        "exp": datetime.now(timezone.utc) + timedelta(hours=24),
    }
    token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
    return {"access_token": token, "token_type": "bearer"}
# Protected endpoint
@app.get("/protected")
async def protected_endpoint(request: Request):
    """Return a greeting plus whatever is stored under ``scope["user"]`` (or None)."""
    current_user = request.scope.get("user")
    response_body = {"message": "This is protected"}
    response_body["user"] = current_user
    return response_body
#📊 Error Handling Middleware
#Global Exception Handling Middleware
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse
class ErrorHandlingMiddleware:
    """Pure ASGI middleware that converts unhandled exceptions into JSON errors.

    Exceptions raised by the downstream app are logged and mapped to a status
    code: HTTPException -> its own status, ValueError -> 400,
    FileNotFoundError -> 404, anything else -> 500. When the ENVIRONMENT
    variable is "development" the traceback is included in the body.

    If the downstream app had already started its response, a new response
    can no longer be sent, so the exception is re-raised instead.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message):
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception as exc:
            if response_started:
                # Headers are already on the wire; a second
                # http.response.start would violate the ASGI protocol.
                raise
            request = Request(scope, receive)
            await self._handle_error(exc, request, send)

    async def _handle_error(self, exc: Exception, request: Request, send):
        """Log ``exc`` and send a JSON error response describing it."""
        # Log error (pass the exception explicitly so its traceback is logged)
        logger.error("Unhandled exception: %s", exc, exc_info=exc)
        # Determine response based on exception type
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            detail = exc.detail
        elif isinstance(exc, ValueError):
            status_code = 400
            detail = "Invalid input data"
        elif isinstance(exc, FileNotFoundError):
            status_code = 404
            detail = "Resource not found"
        else:
            status_code = 500
            detail = "Internal server error"
        # Build error response
        error_response = {
            "error": {
                "type": exc.__class__.__name__,
                "message": detail,
                "path": str(request.url.path),
                "method": request.method,
                "timestamp": datetime.now().isoformat(),
            }
        }
        # Development environment includes detailed error information
        if os.getenv("ENVIRONMENT") == "development":
            error_response["error"]["traceback"] = "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )
        response = JSONResponse(
            content=error_response,
            status_code=status_code,
        )
        # The Request object carries the ASGI scope and receive channel.
        await response(request.scope, request.receive, send)


app.add_middleware(ErrorHandlingMiddleware)
#🔧 Middleware Best Practices
#Middleware Factory
def create_logging_middleware(log_level: str = "INFO"):
    """Build a middleware factory that logs each HTTP request and response.

    Args:
        log_level: Name of a standard ``logging`` level such as ``"DEBUG"``
            (case-insensitive).

    Returns:
        A callable that takes the next ASGI app and returns the wrapping ASGI
        app, suitable for ``app.add_middleware``.

    Raises:
        ValueError: If ``log_level`` is not a known logging level name.
    """
    # Resolve the level once: a typo now fails at start-up instead of raising
    # AttributeError on every request.
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Unknown log level: {log_level!r}")

    def logging_middleware(app):
        async def middleware(scope, receive, send):
            if scope["type"] != "http":
                await app(scope, receive, send)
                return

            # Monotonic clock for the elapsed-time measurement.
            start_time = time.perf_counter()
            request = Request(scope, receive)
            logger.log(level, "Request: %s %s", request.method, request.url)

            async def send_wrapper(message):
                if message["type"] == "http.response.start":
                    process_time = time.perf_counter() - start_time
                    logger.log(level, "Response: %s - %.4fs", message["status"], process_time)
                await send(message)

            await app(scope, receive, send_wrapper)

        return middleware

    return logging_middleware
# Use the factory to create the middleware. The root logger is configured
# with logging.basicConfig(level=logging.INFO) and this module's logger has
# no level of its own, so DEBUG records would be dropped; log at INFO.
app.add_middleware(create_logging_middleware("INFO"))
#Conditional Middleware
def conditional_middleware(condition_func):
    """Class decorator that enables a middleware only when a condition holds.

    ``condition_func`` is called each time the decorated factory is invoked
    (i.e. when the middleware stack is built). If it returns a truthy value,
    the real middleware is constructed; otherwise the next app is returned
    unchanged, so a disabled middleware adds no extra layer.

    Extra arguments given to the factory (e.g. options passed through
    ``app.add_middleware(Cls, option=...)``) are forwarded to the class.
    The undecorated class remains reachable as ``__wrapped__``.
    """
    def decorator(middleware_class):
        # updated=() keeps the class __dict__ from being copied onto the function.
        @functools.wraps(middleware_class, updated=())
        def wrapper(app, *args, **kwargs):
            if condition_func():
                return middleware_class(app, *args, **kwargs)
            # Condition not met: skip the middleware entirely.
            return app
        return wrapper
    return decorator


@conditional_middleware(lambda: os.getenv("ENABLE_PROFILING") == "true")
class ProfilingMiddleware:
    """Profile HTTP requests with cProfile and dump the stats to a temp file.

    Only one request is profiled at a time. On recent CPython versions,
    enabling a second profiler while one is active can raise ValueError, and
    overlapping profiles would mix unrelated requests anyway. Requests that
    arrive while a profile is running are passed through unprofiled.

    NOTE: while enabled, the profiler records everything the event loop runs,
    including other concurrent requests.
    """

    def __init__(self, app):
        self.app = app
        self._busy = False

    async def __call__(self, scope, receive, send):
        # Non-HTTP scopes (e.g. lifespan, which lasts for the whole server
        # lifetime) are never profiled.
        if scope["type"] != "http" or self._busy:
            await self.app(scope, receive, send)
            return

        # Imported lazily so profiling modules load only when actually used.
        import cProfile
        import tempfile

        self._busy = True
        profiler = cProfile.Profile()
        try:
            profiler.enable()
            try:
                await self.app(scope, receive, send)
            finally:
                profiler.disable()
        finally:
            self._busy = False
        # Save the analysis results (inspect with pstats or snakeviz).
        path = os.path.join(tempfile.gettempdir(), f"profile_{time.time()}.prof")
        profiler.dump_stats(path)
app.add_middleware(ProfilingMiddleware)
#Middleware Configuration
try:
    # Pydantic v2 moved BaseSettings into the separate pydantic-settings package.
    from pydantic_settings import BaseSettings
except ImportError:
    # Pydantic v1
    from pydantic import BaseSettings


class MiddlewareSettings(BaseSettings):
    """Middleware switches, read from the environment and a ``.env`` file."""

    enable_cors: bool = True
    cors_origins: list = ["*"]
    enable_gzip: bool = True
    gzip_minimum_size: int = 1000
    enable_rate_limiting: bool = False
    rate_limit_calls: int = 100
    rate_limit_period: int = 60

    class Config:
        env_file = ".env"


settings = MiddlewareSettings()


def setup_middleware(app: FastAPI):
    """Register all middleware on ``app`` according to ``settings``.

    Per the FastAPI docs, the most recently added middleware is the outermost
    layer, so ErrorHandlingMiddleware is added last to wrap all the others.
    """
    # CORS
    if settings.enable_cors:
        # Wildcard origins combined with credentials would let any website
        # send authenticated cross-origin requests, so credentials are only
        # allowed when the origins are listed explicitly.
        allow_credentials = "*" not in settings.cors_origins
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins,
            allow_credentials=allow_credentials,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    # GZip compression
    if settings.enable_gzip:
        app.add_middleware(
            GZipMiddleware,
            minimum_size=settings.gzip_minimum_size,
        )
    # Rate limiting: pass the class and options; Starlette builds the
    # instance around the next app (a lambda would receive no app and would
    # capture the global `app` instead of this function's parameter).
    if settings.enable_rate_limiting:
        app.add_middleware(
            RateLimitMiddleware,
            calls=settings.rate_limit_calls,
            period=settings.rate_limit_period,
        )
    # Always add middleware
    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)


# Setup middleware
# NOTE: intended for a fresh `app`; earlier sections of this chapter already
# registered several of these middleware, and adding them again duplicates them.
setup_middleware(app)
#Summary
This chapter detailed FastAPI's middleware system:
- ✅ Middleware Basics: Execution flow, basic concepts
- ✅ Built-in Middleware: CORS, HTTPS redirect, GZip, etc.
- ✅ Custom Middleware: Request ID, performance monitoring, rate limiting, authentication
- ✅ Error Handling: Global exception handling middleware
- ✅ Best Practices: Middleware factories, conditional middleware, configuration management
Middleware is a powerful tool for handling cross-cutting concerns; used judiciously, it can greatly improve an application's functionality and maintainability.
::: tip Middleware Design Recommendations
- Keep middleware responsibilities single
- Consider middleware execution order
- Provide configuration switches to enable/disable middleware
- Pay attention to performance impact, avoid blocking operations
- Do error handling and logging well
- Write testable middleware code
:::
In the next chapter, we will learn FastAPI's dependency injection system and understand how to manage application dependencies.