FastAPI Middleware

Overview

Middleware is a key mechanism for handling cross-cutting concerns in web applications, such as logging, performance monitoring, security checks, and CORS processing. FastAPI is built on Starlette and inherits its middleware system. This chapter explains in detail how to use and create middleware.

🔧 Middleware Basic Concepts

Middleware Execution Flow

mermaid
graph TB
    A[Client Request] --> B[Middleware 1 - Request Processing]
    B --> C[Middleware 2 - Request Processing]
    C --> D[Route Handler Function]
    D --> E[Middleware 2 - Response Processing]
    E --> F[Middleware 1 - Response Processing]
    F --> G[Client Response]

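Middleware follows an "onion model": each layer wraps the next, so request processing runs from the outermost layer inward and response processing runs in the reverse order. In FastAPI (via Starlette), the middleware added last becomes the outermost layer, so it sees the request first and the response last.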
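
The ordering can be checked without a server. The sketch below is a minimal illustration in plain ASGI terms, not FastAPI's own machinery: `make_layer`, `endpoint`, and `events` are hypothetical names, and wrapping by hand stands in for what the framework does when it builds the middleware stack.

```python
import asyncio

events = []  # records the order in which layers run


def make_layer(name, inner):
    """Wrap an ASGI-style callable and record entry and exit (hypothetical helper)."""
    async def layer(scope, receive, send):
        events.append(f"{name}: request")
        await inner(scope, receive, send)
        events.append(f"{name}: response")
    return layer


async def endpoint(scope, receive, send):
    events.append("handler")


# Wrap from the inside out: the layer applied last is the outermost one
stack = make_layer("middleware 1", make_layer("middleware 2", endpoint))

asyncio.run(stack({"type": "http"}, None, None))
print(events)
```

The printed list shows the request passing through middleware 1, then middleware 2, then the handler, with the response phases running in the opposite order.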

Basic Middleware Example

python
from fastapi import FastAPI, Request, Response
import asyncio  # needed by the /slow endpoint below
import time
import logging

app = FastAPI()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@app.middleware("http")
async def log_requests(request: Request, call_next):
    start_time = time.time()

    # Request processing
    logger.info(f"Request start: {request.method} {request.url}")

    # Call next middleware or route handler
    response = await call_next(request)

    # Response processing
    process_time = time.time() - start_time
    response.headers["X-Process-Time"] = str(process_time)
    logger.info(f"Request complete: {request.method} {request.url} - {response.status_code} - {process_time:.4f}s")

    return response

@app.get("/")
async def read_root():
    return {"message": "Hello World"}

@app.get("/slow")
async def slow_endpoint():
    await asyncio.sleep(2)  # Simulate slow operation
    return {"message": "This was slow"}

🛡️ Built-in Middleware

CORS Middleware

python
import os

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Register CORSMiddleware once, choosing the policy by environment
if os.getenv("ENVIRONMENT") == "development":
    # Development: allow all origins. Credentials stay disabled here, since
    # wildcard origins and credentialed requests do not combine under the CORS rules
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
else:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost:3000", "https://myapp.com"],  # Allowed origins
        allow_credentials=True,  # Allow cookies and Authorization headers
        allow_methods=["GET", "POST", "PUT", "DELETE"],  # Allowed methods
        allow_headers=["*"],  # Allowed headers
        expose_headers=["X-Custom-Header"],  # Headers exposed to browser scripts
        max_age=3600,  # Preflight cache time (seconds)
    )

HTTPS Redirect Middleware

python
import os

from fastapi.middleware.httpsredirect import HTTPSRedirectMiddleware

# Production environment force HTTPS
if os.getenv("ENVIRONMENT") == "production":
    app.add_middleware(HTTPSRedirectMiddleware)

Trusted Host Middleware

python
from fastapi.middleware.trustedhost import TrustedHostMiddleware

app.add_middleware(
    TrustedHostMiddleware,
    allowed_hosts=["example.com", "*.example.com", "localhost", "127.0.0.1"]
)

GZip Compression Middleware

python
from fastapi.middleware.gzip import GZipMiddleware

app.add_middleware(GZipMiddleware, minimum_size=1000)

🔨 Custom Middleware

Request ID Middleware

python
import uuid
from contextvars import ContextVar

# Use ContextVar to store request ID
request_id_contextvar: ContextVar[str] = ContextVar('request_id', default="")

class RequestIDMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            # Generate request ID
            request_id = str(uuid.uuid4())
            request_id_contextvar.set(request_id)

            # Add to response headers
            async def send_wrapper(message):
                if message["type"] == "http.response.start":
                    headers = list(message.get("headers", []))
                    headers.append([b"x-request-id", request_id.encode()])
                    message["headers"] = headers
                await send(message)

            await self.app(scope, receive, send_wrapper)
        else:
            await self.app(scope, receive, send)

app.add_middleware(RequestIDMiddleware)

def get_request_id() -> str:
    """Get current request's ID"""
    return request_id_contextvar.get()

@app.get("/test")
async def test_endpoint():
    return {"request_id": get_request_id(), "message": "Test endpoint"}
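
Why a `ContextVar` rather than a module-level variable: each asyncio task works on its own copy of the context, so concurrent requests do not overwrite each other's IDs. The following is a minimal sketch of that isolation using only the standard library; `handle` and `main` are hypothetical names that simulate requests and are not part of FastAPI.

```python
import asyncio
import uuid
from contextvars import ContextVar

request_id_var: ContextVar[str] = ContextVar("request_id", default="")


async def handle(name: str, results: dict) -> None:
    """Simulate one request: set an ID, yield to other tasks, then read it back."""
    request_id = f"{name}-{uuid.uuid4()}"
    request_id_var.set(request_id)
    await asyncio.sleep(0)  # let the other simulated requests run
    results[name] = (request_id, request_id_var.get())


async def main() -> dict:
    results: dict = {}
    # gather() runs each coroutine in its own task, each with its own context copy
    await asyncio.gather(*(handle(n, results) for n in ("a", "b", "c")))
    return results


results = asyncio.run(main())
for name, (expected, seen) in results.items():
    print(name, expected == seen)
```

Each simulated request reads back the ID it set, even though the three run interleaved on one event loop.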

Performance Monitoring Middleware

python
import psutil
from typing import Dict, Any

class PerformanceMiddleware:
    def __init__(self, app):
        self.app = app
        self.stats = {
            "total_requests": 0,
            "total_time": 0,
            "avg_response_time": 0,
            "request_count_by_method": {},
            "request_count_by_status": {}
        }

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        start_time = time.time()
        request_method = scope["method"]
        status_code = None

        async def send_wrapper(message):
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]

                # Add performance headers
                headers = list(message.get("headers", []))

                # System resource usage
                cpu_percent = psutil.cpu_percent()
                memory_percent = psutil.virtual_memory().percent

                headers.extend([
                    [b"x-cpu-usage", f"{cpu_percent:.1f}%".encode()],
                    [b"x-memory-usage", f"{memory_percent:.1f}%".encode()],
                    [b"x-request-count", str(self.stats["total_requests"]).encode()]
                ])

                message["headers"] = headers

            await send(message)

        await self.app(scope, receive, send_wrapper)

        # Update statistics
        end_time = time.time()
        response_time = end_time - start_time

        self.stats["total_requests"] += 1
        self.stats["total_time"] += response_time
        self.stats["avg_response_time"] = self.stats["total_time"] / self.stats["total_requests"]

        # Statistics by method
        self.stats["request_count_by_method"][request_method] = \
            self.stats["request_count_by_method"].get(request_method, 0) + 1

        # Statistics by status code
        if status_code:
            self.stats["request_count_by_status"][status_code] = \
                self.stats["request_count_by_status"].get(status_code, 0) + 1

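# add_middleware() creates the middleware itself and passes the wrapped app as the
# first argument, so register a factory that keeps a reference to the instance
performance_middleware = None

def performance_middleware_factory(app):
    global performance_middleware
    performance_middleware = PerformanceMiddleware(app)
    return performance_middleware

app.add_middleware(performance_middleware_factory)

@app.get("/stats")
async def get_performance_stats():
    # The instance exists once the middleware stack has been built
    return performance_middleware.stats if performance_middleware else {}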

Rate Limiting Middleware

python
import asyncio
from collections import defaultdict
from datetime import datetime, timedelta

class RateLimitMiddleware:
    def __init__(self, app, calls: int = 100, period: int = 60):
        self.app = app
        self.calls = calls  # Allowed number of calls
        self.period = period  # Time window (seconds)
        self.clients = defaultdict(list)  # Store client request times
        self.cleanup_task = None

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Get client IP
        client_ip = self._get_client_ip(scope)
        current_time = datetime.now()

        # Cleanup old requests
        self._cleanup_old_requests(client_ip, current_time)

        # Check rate limit
        if len(self.clients[client_ip]) >= self.calls:
            # Exceeds limit, return 429
            response = {
                "type": "http.response.start",
                "status": 429,
                "headers": [
                    [b"content-type", b"application/json"],
                    [b"x-ratelimit-limit", str(self.calls).encode()],
                    [b"x-ratelimit-remaining", b"0"],
                    [b"x-ratelimit-reset", str(int((current_time + timedelta(seconds=self.period)).timestamp())).encode()]
                ]
            }
            await send(response)

            body = {
                "type": "http.response.body",
                "body": b'{"error": "Rate limit exceeded", "message": "Too many requests"}'
            }
            await send(body)
            return

        # Record request time
        self.clients[client_ip].append(current_time)

        # Add rate limit headers
        remaining = self.calls - len(self.clients[client_ip])

        async def send_wrapper(message):
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                headers.extend([
                    [b"x-ratelimit-limit", str(self.calls).encode()],
                    [b"x-ratelimit-remaining", str(remaining).encode()],
                    [b"x-ratelimit-reset", str(int((current_time + timedelta(seconds=self.period)).timestamp())).encode()]
                ])
                message["headers"] = headers
            await send(message)

        await self.app(scope, receive, send_wrapper)

    def _get_client_ip(self, scope):
        # Try to get real IP from headers
        headers = dict(scope.get("headers", []))

        # Check common proxy headers
        for header in [b"x-forwarded-for", b"x-real-ip", b"cf-connecting-ip"]:
            if header in headers:
                ip = headers[header].decode().split(",")[0].strip()
                if ip:
                    return ip

        # Use directly connected IP
        client = scope.get("client")
        return client[0] if client else "unknown"

    def _cleanup_old_requests(self, client_ip, current_time):
        cutoff_time = current_time - timedelta(seconds=self.period)
        self.clients[client_ip] = [
            req_time for req_time in self.clients[client_ip]
            if req_time > cutoff_time
        ]

        # If client has no active requests, delete record
        if not self.clients[client_ip]:
            del self.clients[client_ip]

# Apply rate limiting middleware: maximum 100 requests per minute.
# Pass the class and its options; add_middleware() supplies the app argument itself
app.add_middleware(RateLimitMiddleware, calls=100, period=60)
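
The windowing logic can be separated from the ASGI plumbing and exercised on its own. Below is a minimal sketch of a sliding-window check, assuming timestamps are passed in explicitly so the behaviour is deterministic; `SlidingWindowLimiter` and `allow` are hypothetical names, and a `deque` replaces the list so expired entries are dropped from the front.

```python
from collections import defaultdict, deque


class SlidingWindowLimiter:
    """Allow at most `calls` requests per `period` seconds for each key (illustrative sketch)."""

    def __init__(self, calls: int, period: float):
        self.calls = calls
        self.period = period
        self.hits = defaultdict(deque)  # key -> timestamps of accepted requests

    def allow(self, key: str, now: float) -> bool:
        window = self.hits[key]
        # Drop timestamps that have fallen out of the window
        while window and window[0] <= now - self.period:
            window.popleft()
        if len(window) >= self.calls:
            return False
        window.append(now)
        return True


limiter = SlidingWindowLimiter(calls=2, period=60)
decisions = [limiter.allow("10.0.0.1", t) for t in (0, 1, 2, 61)]
print(decisions)  # [True, True, False, True]
```

The third request is rejected because two requests are already inside the 60-second window; by `t=61` both have expired, so the fourth is accepted. State kept this way lives in one process only, so several worker processes would each enforce their own limit.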

Authentication Middleware

python
import jwt
from fastapi import HTTPException, status

class AuthenticationMiddleware:
    def __init__(self, app, secret_key: str, excluded_paths: list = None):
        self.app = app
        self.secret_key = secret_key
        self.excluded_paths = excluded_paths or ["/", "/docs", "/redoc", "/openapi.json", "/login"]

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        path = scope["path"]
        method = scope["method"]

        # Check if it's an excluded path
        if self._is_excluded_path(path):
            await self.app(scope, receive, send)
            return

        # Get Authorization header
        headers = dict(scope.get("headers", []))
        auth_header = headers.get(b"authorization")

        if not auth_header:
            await self._send_unauthorized(send, "Missing authorization header")
            return

        try:
            # Parse Bearer token
            auth_value = auth_header.decode()
            if not auth_value.startswith("Bearer "):
                await self._send_unauthorized(send, "Invalid authorization format")
                return

            token = auth_value[7:]  # Remove "Bearer " prefix

            # Validate JWT token
            payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])

            # Add user info to scope
            scope["user"] = payload

        except jwt.ExpiredSignatureError:
            await self._send_unauthorized(send, "Token has expired")
            return
        except jwt.InvalidTokenError:
            await self._send_unauthorized(send, "Invalid token")
            return

        await self.app(scope, receive, send)

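    def _is_excluded_path(self, path: str) -> bool:
        # Match exactly or on a sub-path. A bare startswith() would treat "/"
        # as a prefix of every path and disable authentication entirely
        for excluded in self.excluded_paths:
            if path == excluded:
                return True
            if excluded != "/" and path.startswith(excluded.rstrip("/") + "/"):
                return True
        return False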

    async def _send_unauthorized(self, send, message: str):
        response = {
            "type": "http.response.start",
            "status": 401,
            "headers": [[b"content-type", b"application/json"]]
        }
        await send(response)

        body = {
            "type": "http.response.body",
            "body": f'{{"error": "Unauthorized", "message": "{message}"}}'.encode()
        }
        await send(body)

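# Apply authentication middleware
SECRET_KEY = "your-secret-key"  # Placeholder; load the real key from configuration
# Pass the class and its options; add_middleware() supplies the app argument itself
app.add_middleware(AuthenticationMiddleware, secret_key=SECRET_KEY)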

# Login endpoint (no authentication required)
@app.post("/login")
async def login(username: str, password: str):
    # Simple user authentication
    if username == "admin" and password == "secret":
        payload = {
            "user_id": 1,
            "username": username,
            "exp": datetime.utcnow() + timedelta(hours=24)
        }
        token = jwt.encode(payload, SECRET_KEY, algorithm="HS256")
        return {"access_token": token, "token_type": "bearer"}
    else:
        raise HTTPException(status_code=401, detail="Invalid credentials")

# Protected endpoint
@app.get("/protected")
async def protected_endpoint(request: Request):
    user = request.scope.get("user")
    return {"message": "This is protected", "user": user}
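
Path exclusion deserves a quick check of its own, because a plain prefix comparison against `"/"` matches every path. The sketch below contrasts a bare prefix test with a stricter match; `is_excluded` is a hypothetical standalone helper written for illustration.

```python
def is_excluded(path: str, excluded_paths: list[str]) -> bool:
    """Return True for an exact match or a sub-path of an excluded entry (hypothetical helper)."""
    for excluded in excluded_paths:
        if path == excluded:
            return True
        # "/" is only ever matched exactly; other entries also cover their sub-paths
        if excluded != "/" and path.startswith(excluded.rstrip("/") + "/"):
            return True
    return False


excluded = ["/", "/docs", "/login"]

# A bare prefix test treats "/" as a prefix of every path
print(any("/protected".startswith(e) for e in excluded))  # True
print(is_excluded("/protected", excluded))                # False
print(is_excluded("/docs/oauth2-redirect", excluded))     # True
print(is_excluded("/docsx", excluded))                    # False
```

With the stricter match, `/protected` still requires a token while `/docs` and its sub-paths remain open.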

📊 Error Handling Middleware

Global Exception Handling Middleware

python
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse

class ErrorHandlingMiddleware:
    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        try:
            await self.app(scope, receive, send)
        except Exception as exc:
            request = Request(scope, receive)
            await self._handle_error(exc, request, send)

    async def _handle_error(self, exc: Exception, request: Request, send):
        # Log error
        logger.error(f"Unhandled exception: {exc}", exc_info=True)

        # Determine response based on exception type
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            detail = exc.detail
        elif isinstance(exc, ValueError):
            status_code = 400
            detail = "Invalid input data"
        elif isinstance(exc, FileNotFoundError):
            status_code = 404
            detail = "Resource not found"
        else:
            status_code = 500
            detail = "Internal server error"

        # Build error response
        error_response = {
            "error": {
                "type": exc.__class__.__name__,
                "message": detail,
                "path": str(request.url.path),
                "method": request.method,
                "timestamp": datetime.now().isoformat()
            }
        }

        # Development environment includes detailed error information
        if os.getenv("ENVIRONMENT") == "development":
            error_response["error"]["traceback"] = traceback.format_exc()

        response = JSONResponse(
            content=error_response,
            status_code=status_code
        )

        # scope and receive are not in this method's namespace; take them from the request
        await response(request.scope, request.receive, send)

app.add_middleware(ErrorHandlingMiddleware)
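
One case an error-handling layer has to consider is an exception raised after the response has already started: at that point the status line and headers are on the wire, and a second `http.response.start` would be invalid. The following is a minimal sketch of tracking that state in pure ASGI, using only the standard library; `SafeErrorMiddleware`, `failing_app`, and `run_once` are hypothetical names, and the fake `send`/`receive` callables stand in for a real server.

```python
import asyncio
import json


class SafeErrorMiddleware:
    """Pure ASGI sketch: convert unhandled exceptions into a JSON 500 response."""

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        response_started = False

        async def send_wrapper(message):
            nonlocal response_started
            if message["type"] == "http.response.start":
                response_started = True
            await send(message)

        try:
            await self.app(scope, receive, send_wrapper)
        except Exception:
            if response_started:
                raise  # headers already sent; a second response would be invalid
            body = json.dumps({"error": "Internal server error"}).encode()
            await send({
                "type": "http.response.start",
                "status": 500,
                "headers": [(b"content-type", b"application/json")],
            })
            await send({"type": "http.response.body", "body": body})


async def failing_app(scope, receive, send):
    raise ValueError("boom")


async def run_once():
    sent = []

    async def send(message):
        sent.append(message)

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    scope = {"type": "http", "method": "GET", "path": "/"}
    await SafeErrorMiddleware(failing_app)(scope, receive, send)
    return sent


messages = asyncio.run(run_once())
print(messages[0]["status"], messages[1]["body"])
```

Calling the middleware directly with fake `send` and `receive` callables also shows how pure ASGI middleware can be tested without starting a server.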

🔧 Middleware Best Practices

Middleware Factory

python
def create_logging_middleware(log_level: str = "INFO"):
    """Create logging middleware factory function"""

    def logging_middleware(app):
        async def middleware(scope, receive, send):
            if scope["type"] == "http":
                start_time = time.time()
                request = Request(scope, receive)

                logger.log(
                    getattr(logging, log_level),
                    f"Request: {request.method} {request.url}"
                )

                async def send_wrapper(message):
                    if message["type"] == "http.response.start":
                        process_time = time.time() - start_time
                        logger.log(
                            getattr(logging, log_level),
                            f"Response: {message['status']} - {process_time:.4f}s"
                        )
                    await send(message)

                await app(scope, receive, send_wrapper)
            else:
                await app(scope, receive, send)

        return middleware

    return logging_middleware

# Use factory to create middleware
app.add_middleware(create_logging_middleware("DEBUG"))

Conditional Middleware

python
def conditional_middleware(condition_func):
    """Conditional middleware decorator"""
    def decorator(middleware_class):
        def wrapper(app):
            if condition_func():
                return middleware_class(app)
            else:
                # If condition not met, return passthrough middleware
                async def passthrough(scope, receive, send):
                    await app(scope, receive, send)
                return passthrough
        return wrapper
    return decorator

@conditional_middleware(lambda: os.getenv("ENABLE_PROFILING") == "true")
class ProfilingMiddleware:
    def __init__(self, app):
        self.app = app

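    async def __call__(self, scope, receive, send):
        # Profile HTTP requests only; lifespan and WebSocket scopes pass through
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        import cProfile
        profiler = cProfile.Profile()
        profiler.enable()
        try:
            await self.app(scope, receive, send)
        finally:
            profiler.disable()
            # Save the analysis results, even if the request raised
            profiler.dump_stats(f"/tmp/profile_{time.time()}.prof")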

app.add_middleware(ProfilingMiddleware)

Middleware Configuration

python
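# Pydantic v2 moved BaseSettings to the separate pydantic-settings package;
# on Pydantic v1, use `from pydantic import BaseSettings` with a nested `class Config`
from pydantic_settings import BaseSettings, SettingsConfigDict

class MiddlewareSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env")

    enable_cors: bool = True
    cors_origins: list[str] = ["*"]
    enable_gzip: bool = True
    gzip_minimum_size: int = 1000
    enable_rate_limiting: bool = False
    rate_limit_calls: int = 100
    rate_limit_period: int = 60

settings = MiddlewareSettings()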

def setup_middleware(app: FastAPI):
    """Configure all middleware"""

    # CORS
    if settings.enable_cors:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=settings.cors_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # GZip compression
    if settings.enable_gzip:
        app.add_middleware(
            GZipMiddleware,
            minimum_size=settings.gzip_minimum_size
        )

    # Rate limiting
    if settings.enable_rate_limiting:
        # Pass the class and its options; add_middleware() supplies the app argument itself
        app.add_middleware(
            RateLimitMiddleware,
            calls=settings.rate_limit_calls,
            period=settings.rate_limit_period,
        )

    # Always add middleware
    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(ErrorHandlingMiddleware)

# Setup middleware
setup_middleware(app)

Summary

This chapter detailed FastAPI's middleware system:

  • Middleware Basics: Execution flow, basic concepts
  • Built-in Middleware: CORS, HTTPS redirect, GZip, etc.
  • Custom Middleware: Request ID, performance monitoring, rate limiting, authentication
  • Error Handling: Global exception handling middleware
  • Best Practices: Middleware factories, conditional middleware, configuration management

Middleware is a powerful tool for handling cross-cutting concerns. Used appropriately, it can improve both the functionality and the maintainability of an application.

Middleware Design Recommendations

  • Give each middleware a single responsibility
  • Consider the order in which middleware executes
  • Provide configuration switches to enable or disable middleware
  • Watch the performance impact and avoid blocking operations
  • Handle errors and log them properly
  • Write middleware that is easy to test
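
On the point about blocking operations: middleware runs on the event loop, so a synchronous call inside it stalls every other request until it returns. One common remedy is to hand such work to a worker thread with `asyncio.to_thread`. The sketch below illustrates the idea under stated assumptions: `blocking_call` and `signal_later` are hypothetical stand-ins, and the blocking call can only finish if the event loop stays free to fire the signal.

```python
import asyncio
import threading


def blocking_call(release: threading.Event) -> str:
    """Stand-in for blocking work (hypothetical); waits until the event loop signals it."""
    released = release.wait(timeout=5)
    return "released" if released else "timed out"


async def signal_later(release: threading.Event) -> None:
    """Runs on the event loop; it can only fire if the loop is not blocked."""
    await asyncio.sleep(0.01)
    release.set()


async def main() -> str:
    release = threading.Event()
    # to_thread moves the blocking call to a worker thread, keeping the loop free
    result, _ = await asyncio.gather(
        asyncio.to_thread(blocking_call, release),
        signal_later(release),
    )
    return result


outcome = asyncio.run(main())
print(outcome)  # released
```

Had `blocking_call` been invoked directly inside a coroutine, the loop would have been stuck in `wait()` and `signal_later` could not have run until the timeout expired.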

In the next chapter, we will learn FastAPI's dependency injection system and understand how to manage application dependencies.
